diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 45118c6..1b63f0e 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -36,7 +36,6 @@ jobs: shell: pwsh run: | choco install ninja -y - choco install opencppcoverage -y New-Item -ItemType Directory -Force "$env:GITHUB_WORKSPACE\vcpkg_cache" | Out-Null git clone https://github.com/microsoft/vcpkg "$env:GITHUB_WORKSPACE\vcpkg" & "$env:GITHUB_WORKSPACE\vcpkg\bootstrap-vcpkg.bat" @@ -48,6 +47,15 @@ jobs: zlib ` --triplet x64-windows + - name: Install OpenCppCoverage (Windows) + if: runner.os == 'Windows' + shell: pwsh + run: | + $installer = "$env:TEMP\OpenCppCoverageSetup.exe" + Invoke-WebRequest -Uri "https://github.com/OpenCppCoverage/OpenCppCoverage/releases/download/release-0.9.9.0/OpenCppCoverageSetup-x64-0.9.9.0.exe" -OutFile $installer + Start-Process -FilePath $installer -ArgumentList "/VERYSILENT" -Wait + echo "C:\Program Files\OpenCppCoverage" >> $env:GITHUB_PATH + - name: Setup MSVC environment (Windows) if: runner.os == 'Windows' uses: ilammy/msvc-dev-cmd@v1 @@ -89,7 +97,7 @@ jobs: if: runner.os == 'Windows' shell: pwsh run: > - & "$env:ProgramFiles\OpenCppCoverage\OpenCppCoverage.exe" + & "C:\Program Files\OpenCppCoverage\OpenCppCoverage.exe" --export_type cobertura:coverage.xml --sources ${{ github.workspace }} -- ctest --test-dir build --output-on-failure -C Release diff --git a/.gitignore b/.gitignore index 66a6152..2490d42 100644 --- a/.gitignore +++ b/.gitignore @@ -20,4 +20,6 @@ docs/api/ test_package/build test_package/CMakeUserPresets.json test-data/* -bench/results* \ No newline at end of file +bench/results* +**build-tests +*.jsonl diff --git a/README.md b/README.md index 3a2fb42..9766891 100644 --- a/README.md +++ b/README.md @@ -1,14 +1,49 @@ # TRX-cpp + +[![Documentation](https://readthedocs.org/projects/trx-cpp/badge/?version=latest)](https://trx-cpp.readthedocs.io/en/latest/) 
[![codecov](https://codecov.io/gh/tee-ar-ex/trx-cpp/branch/main/graph/badge.svg)](https://codecov.io/gh/tee-ar-ex/trx-cpp) -TRX-cpp is a C++11 library for reading, writing, and memory-mapping the TRX -tractography file format (zip archives or on-disk directories of memmaps). +A C++17 library for reading, writing, and memory-mapping the [TRX tractography format](https://github.com/tee-ar-ex/trx-spec) — efficient storage for large-scale tractography data. -## Documentation +## Features + +- **Zero-copy memory mapping** — positions, DPV, and DPS arrays are exposed as `Eigen::Map` views directly over memory-mapped files; no unnecessary copies for large tractograms +- **Streaming writes** — `TrxStream` appends streamlines one at a time and finalizes to a standard TRX archive or directory, suitable when the total count is unknown at the start +- **Spatial queries** — build per-streamline axis-aligned bounding boxes (AABBs) and efficiently extract spatial subsets; designed for interactive slice-view workflows +- **Typed and type-erased APIs** — `TrxFile
` gives compile-time type safety; `AnyTrxFile` dispatches at runtime when the dtype is read from disk +- **ZIP and directory storage** — read and write `.trx` zip archives and plain on-disk directories with the same API +- **Optional NIfTI support** — read qform/sform affines from `.nii` / `.nii.gz` and embed them in the TRX header + +## Quick start + +**Dependencies:** a C++17 compiler, [libzip](https://libzip.org/), [Eigen 3.4+](https://eigen.tuxfamily.org/) + +```cmake +# CMakeLists.txt +find_package(trx-cpp CONFIG REQUIRED) +target_link_libraries(my_app PRIVATE trx-cpp::trx) +``` + +```cpp +#include -Project documentation (build/usage instructions and API reference) is hosted at -https://trx-cpp.readthedocs.io/en/latest/. +// Load any TRX file — dtype detected at runtime +auto trx = trx::load_any("tracks.trx"); + +std::cout << trx.num_streamlines() << " streamlines, " + << trx.num_vertices() << " vertices\n"; + +// Access positions as an Eigen matrix (zero-copy) +auto positions = trx.positions.as_matrix(); // (NB_VERTICES, 3) + +trx.close(); +``` + +See [Building](https://trx-cpp.readthedocs.io/en/latest/building.html) for platform-specific dependency installation and [Quick Start](https://trx-cpp.readthedocs.io/en/latest/quick_start.html) for a complete first program. + +## Documentation +Full documentation is at **[trx-cpp.readthedocs.io](https://trx-cpp.readthedocs.io/en/latest/)**. 
## Third-party notices diff --git a/bench/bench_trx_realdata.cpp b/bench/bench_trx_realdata.cpp index 035916d..bee558d 100644 --- a/bench/bench_trx_realdata.cpp +++ b/bench/bench_trx_realdata.cpp @@ -30,11 +30,16 @@ #include #include #endif +#if defined(__APPLE__) +#include +#endif namespace { using Eigen::half; std::string g_reference_trx_path; +bool g_reference_has_dpv = false; +size_t g_reference_streamline_count = 0; constexpr float kMinLengthMm = 20.0f; constexpr float kMaxLengthMm = 500.0f; @@ -122,6 +127,34 @@ double get_max_rss_kb() { #endif } +// Returns the current RSS (not the process-wide peak) so each benchmark can +// track its own peak rather than inheriting the all-time process high-water mark. +double get_current_rss_kb() { +#if defined(__APPLE__) + struct mach_task_basic_info info; + mach_msg_type_number_t count = MACH_TASK_BASIC_INFO_COUNT; + if (task_info(mach_task_self(), MACH_TASK_BASIC_INFO, + reinterpret_cast(&info), &count) != KERN_SUCCESS) { + return 0.0; + } + return static_cast(info.resident_size) / 1024.0; +#elif defined(__linux__) + std::ifstream status("/proc/self/status"); + std::string line; + while (std::getline(status, line)) { + if (line.rfind("VmRSS:", 0) == 0) { + const auto pos = line.find_first_of("0123456789"); + if (pos != std::string::npos) { + return static_cast(std::stoull(line.substr(pos))); + } + } + } + return 0.0; +#else + return 0.0; +#endif +} + size_t parse_env_size(const char *name, size_t default_value) { const char *raw = std::getenv(name); if (!raw || raw[0] == '\0') { @@ -221,7 +254,7 @@ std::vector streamlines_for_benchmarks() { return {only}; } const size_t max_val = parse_env_size("TRX_BENCH_MAX_STREAMLINES", 10000000); - std::vector counts = {10000000, 5000000, 1000000, 500000, 100000}; + std::vector counts = {100000, 500000, 1000000, 5000000, 10000000}; counts.erase(std::remove_if(counts.begin(), counts.end(), [&](size_t v) { return v > max_val; }), counts.end()); if (counts.empty()) { 
counts.push_back(max_val); @@ -365,7 +398,11 @@ std::unique_ptr> build_prefix_subset_trx(size_t streamlines, } if (add_dpv) { - std::vector dpv(out->num_vertices(), 0.5f); + const size_t n_verts = out->num_vertices(); + std::vector dpv(n_verts); + std::mt19937 rng(12345); + std::uniform_real_distribution dist(-1.0f, 1.0f); + for (auto &v : dpv) v = dist(rng); out->add_dpv_from_vector("dpv_random", "float32", dpv); } else { out->data_per_vertex.clear(); @@ -554,79 +591,229 @@ struct TrxOnDisk { size_t shard_processes = 1; }; +// Parse the per-element column count and byte size encoded in a TRX filename. +// TRX convention: .. or . +// e.g. "positions.3.float16" -> {3, 2}, "sift_weights.float32" -> {1, 4} +static std::pair parse_trx_array_dims(const std::string &filename) { + std::vector parts; + std::istringstream ss(filename); + std::string tok; + while (std::getline(ss, tok, '.')) { + if (!tok.empty()) parts.push_back(tok); + } + if (parts.size() < 2) return {1, 4}; + const std::string dtype_str = parts.back(); + const size_t elem_size = static_cast(trx::detail::_sizeof_dtype(dtype_str)); + if (parts.size() >= 3) { + const std::string &maybe_ncols = parts[parts.size() - 2]; + if (!maybe_ncols.empty() && + std::all_of(maybe_ncols.begin(), maybe_ncols.end(), [](unsigned char c) { return std::isdigit(c); })) { + return {static_cast(std::stoul(maybe_ncols)), elem_size}; + } + } + return {1, elem_size}; +} + +#if defined(__unix__) || defined(__APPLE__) +static void truncate_file_to(const std::string &path, off_t byte_size) { + if (::truncate(path.c_str(), byte_size) != 0) { + throw std::runtime_error("truncate " + path + ": " + std::strerror(errno)); + } +} +#endif + +// Truncate every regular file in dir to row_count rows based on the per-file dtype/ncols. 
+static void truncate_array_dir(const std::string &dir_path, size_t row_count) { + std::error_code ec; + if (!trx::fs::exists(dir_path, ec)) return; + for (const auto &entry : trx::fs::directory_iterator(dir_path, ec)) { + if (ec || !entry.is_regular_file()) continue; + const auto [ncols, elem_size] = parse_trx_array_dims(entry.path().filename().string()); + truncate_file_to(entry.path().string(), static_cast(row_count * ncols * elem_size)); + } +} + +// Write random float32 DPV data directly to temp_dir/dpv/dpv_random.float32 in chunks, +// avoiding a full n_vertices allocation for large datasets. +static void write_synthetic_dpv_to_dir(const std::string &temp_dir, size_t n_vertices) { + const std::string dpv_dir = temp_dir + trx::SEPARATOR + "dpv"; + std::error_code ec; + std::filesystem::create_directories(dpv_dir, ec); + const std::string dpv_path = dpv_dir + trx::SEPARATOR + "dpv_random.float32"; + std::ofstream f(dpv_path, std::ios::binary | std::ios::trunc); + if (!f.is_open()) { + throw std::runtime_error("Cannot open DPV output file: " + dpv_path); + } + constexpr size_t kChunkSize = 1024ULL * 1024ULL; // 1M floats = 4 MB per chunk + std::vector chunk(kChunkSize); + std::mt19937 rng(12345); + std::uniform_real_distribution dist(-1.0f, 1.0f); + size_t remaining = n_vertices; + while (remaining > 0) { + const size_t to_write = std::min(kChunkSize, remaining); + for (size_t i = 0; i < to_write; ++i) { + chunk[i] = dist(rng); + } + f.write(reinterpret_cast(chunk.data()), + static_cast(to_write * sizeof(float))); + remaining -= to_write; + } +} + TrxOnDisk build_trx_file_on_disk_single(size_t streamlines, GroupScenario scenario, bool add_dps, bool add_dpv, zip_uint32_t compression, const std::string &out_path_override = "") { - const size_t progress_every = parse_env_size("TRX_BENCH_LOG_PROGRESS_EVERY", 0); - - // Fast path: For 10M (full reference) WITHOUT DPV, just copy and add groups - // DPV requires 1.3B vertices × 8 bytes (vector + internal copy) = 10 GB 
extra memory! if (g_reference_trx_path.empty()) { throw std::runtime_error("Reference TRX path not set."); } - auto ref_trx_check = trx::load(g_reference_trx_path); - const size_t ref_count = ref_trx_check->num_streamlines(); - const bool is_full_reference = (streamlines == ref_count); - ref_trx_check.reset(); // Release mmap - - // Fast path for full reference WITHOUT DPV: work with reference directly - // DPV + 10M requires 40-50 GB due to: mmap(7.6) + dpv_vec(4.8) + dpv_mmap(4.8) + save() overhead(20+) - if (is_full_reference && !add_dpv) { - log_bench_start("build_trx_copy_fast", "streamlines=" + std::to_string(streamlines)); - - // Filesystem copy reference (disk I/O only) - const std::string temp_copy = make_temp_path("trx_ref_copy"); - std::error_code copy_ec; - std::filesystem::copy_file(g_reference_trx_path, temp_copy, - std::filesystem::copy_options::overwrite_existing, copy_ec); - if (copy_ec) { - throw std::runtime_error("Failed to copy reference: " + copy_ec.message()); + const bool is_full_reference = (streamlines == g_reference_streamline_count); + + // Fast path: load the reference (extracts to a fresh unique temp dir each call), optionally + // ftruncate the positions/offsets/dps files to a prefix boundary in O(1), write synthetic + // DPV if needed (chunked for >1M streamlines, vector for smaller), then assign groups and + // save. Avoids the large intermediate positions write from subset_streamlines(). + { + log_bench_start("build_trx_prefix_fast", "streamlines=" + std::to_string(streamlines)); + + // Load reference into a fresh temp dir; each call gets its own unique extraction. + auto ref = trx::load(g_reference_trx_path); + const std::string temp_dir = ref->_uncompressed_folder_handle; + + // Locate positions and offsets files while the mmap is still open. 
+ std::string pos_path, off_path; + { + std::error_code ec; + for (const auto &entry : trx::fs::directory_iterator(temp_dir, ec)) { + if (ec) break; + const std::string fn = entry.path().filename().string(); + if (fn.rfind("positions", 0) == 0) pos_path = entry.path().string(); + else if (fn.rfind("offsets", 0) == 0) off_path = entry.path().string(); + } + } + if (pos_path.empty() || off_path.empty()) { + throw std::runtime_error("positions/offsets not found in " + temp_dir); + } + + // Read vertex_cutoff directly from the offsets file so the dtype (uint32 vs uint64) + // is always respected, regardless of how the Eigen map interprets the mmap width. + const auto [off_ncols, off_elem] = + parse_trx_array_dims(trx::fs::path(off_path).filename().string()); + size_t vertex_cutoff = 0; + { + std::ifstream ofs(off_path, std::ios::binary); + ofs.seekg(static_cast(streamlines * off_ncols * off_elem)); + if (off_elem == 4) { + uint32_t v = 0; + ofs.read(reinterpret_cast(&v), 4); + vertex_cutoff = static_cast(v); + } else { + uint64_t v = 0; + ofs.read(reinterpret_cast(&v), 8); + vertex_cutoff = static_cast(v); + } + } + + // Release mmaps without deleting temp_dir — we own it from here. + ref->_owns_uncompressed_folder = false; + ref.reset(); + + if (!is_full_reference) { + // Truncate positions and offsets to the prefix boundary. + const auto [pos_ncols, pos_elem] = + parse_trx_array_dims(trx::fs::path(pos_path).filename().string()); + truncate_file_to(pos_path, static_cast(vertex_cutoff * pos_ncols * pos_elem)); + truncate_file_to(off_path, static_cast((streamlines + 1) * off_ncols * off_elem)); + + // Truncate DPS arrays to the prefix streamline count. + truncate_array_dir(temp_dir + trx::SEPARATOR + "dps", streamlines); + + // Handle DPV for the prefix case. + if (add_dpv && g_reference_has_dpv) { + // Reference has DPV — truncate its files to the prefix vertex count. 
+ truncate_array_dir(temp_dir + trx::SEPARATOR + "dpv", vertex_cutoff); + } else { + // Either DPV is not wanted, or reference has none (synthetic will be written below). + std::error_code ec2; + std::filesystem::remove_all(trx::fs::path(temp_dir) / "dpv", ec2); + } + + // Patch NB_STREAMLINES and NB_VERTICES in header.json. + const std::string header_path = temp_dir + trx::SEPARATOR + "header.json"; + { + std::ifstream in(header_path); + std::string raw((std::istreambuf_iterator(in)), {}); + std::string parse_err; + json hdr = json::parse(raw, parse_err); + if (!parse_err.empty()) throw std::runtime_error("header.json parse error: " + parse_err); + hdr = trx::_json_set(hdr, "NB_STREAMLINES", static_cast(streamlines)); + hdr = trx::_json_set(hdr, "NB_VERTICES", static_cast(vertex_cutoff)); + std::ofstream out(header_path, std::ios::trunc); + out << hdr.dump(); + } + } + + // Remove DPV for the full-reference case when not wanted. + if (is_full_reference && !add_dpv) { + std::error_code ec; + std::filesystem::remove_all(trx::fs::path(temp_dir) / "dpv", ec); + } + + // Strip DPS if not wanted (applies to both full and prefix). + if (!add_dps) { + std::error_code ec; + std::filesystem::remove_all(trx::fs::path(temp_dir) / "dps", ec); + } + + // Clear any groups the reference may carry; the benchmark owns grouping. + { + std::error_code ec; + std::filesystem::remove_all(trx::fs::path(temp_dir) / "groups", ec); + } + + // Write synthetic DPV in chunks before loading when the reference lacks it and + // the vertex count is large (avoids allocating the full buffer at once). + if (add_dpv && !g_reference_has_dpv && streamlines > 1000000) { + write_synthetic_dpv_to_dir(temp_dir, vertex_cutoff); + } + + auto trx = trx::TrxFile::load_from_directory(temp_dir); + + // Add synthetic DPS if requested but the reference didn't have it. 
+ if (add_dps && trx->data_per_streamline.empty()) { + std::vector ones(streamlines, 1.0f); + trx->add_dps_from_vector("sift_weights", "float32", ones); } - - // Load and modify (just groups, no DPV) - auto trx = trx::load(temp_copy); - - // Add groups + + // Add synthetic DPV via vector for small streamline counts. + if (add_dpv && !g_reference_has_dpv && streamlines <= 1000000) { + std::vector dpv_data(vertex_cutoff); + std::mt19937 rng(12345); + std::uniform_real_distribution dist(-1.0f, 1.0f); + for (auto &v : dpv_data) v = dist(rng); + trx->add_dpv_from_vector("dpv_random", "float32", dpv_data); + } + assign_groups_to_trx(*trx, scenario, streamlines); - - // Note: DPV for 10M handled by sampling path (skipped by default due to 40-50 GB memory) - - // Save to output - const std::string out_path = out_path_override.empty() ? make_temp_path("trx_input") : out_path_override; + + const std::string out_path = + out_path_override.empty() ? make_temp_path("trx_input") : out_path_override; trx::TrxSaveOptions save_opts; save_opts.compression_standard = compression; trx->save(out_path, save_opts); const size_t total_vertices = trx->num_vertices(); trx->close(); - - std::filesystem::remove(temp_copy, copy_ec); - + trx::rm_dir(temp_dir); + if (out_path_override.empty()) { register_cleanup(out_path); } - - log_bench_end("build_trx_copy_fast", "streamlines=" + std::to_string(streamlines)); + + log_bench_end("build_trx_prefix_fast", "streamlines=" + std::to_string(streamlines)); return {out_path, streamlines, total_vertices, 0.0, 1}; } - - // Prefix-subset path: copy first N streamlines into a TrxFile and save. - auto trx_subset = build_prefix_subset_trx(streamlines, scenario, add_dps, add_dpv); - const size_t total_vertices = trx_subset->num_vertices(); - const std::string out_path = out_path_override.empty() ? 
make_temp_path("trx_input") : out_path_override; - trx::TrxSaveOptions save_opts; - save_opts.compression_standard = compression; - trx_subset->save(out_path, save_opts); - trx_subset->close(); - if (progress_every > 0 && (parse_env_bool("TRX_BENCH_CHILD_LOG", false) || parse_env_bool("TRX_BENCH_LOG", false))) { - std::cerr << "[trx-bench] progress build_trx streamlines=" << streamlines << " / " << streamlines << std::endl; - } - if (out_path_override.empty()) { - register_cleanup(out_path); - } - return {out_path, streamlines, total_vertices, 0.0, 1}; } TrxOnDisk build_trx_file_on_disk(size_t streamlines, @@ -730,19 +917,18 @@ static void BM_TrxFileSize_Float16(benchmark::State &state) { const bool add_dpv = state.range(3) != 0; const bool use_zip = state.range(4) != 0; const auto compression = use_zip ? ZIP_CM_DEFLATE : ZIP_CM_STORE; - const size_t skip_zip_at = parse_env_size("TRX_BENCH_SKIP_ZIP_AT", 5000000); - const size_t skip_dpv_at = parse_env_size("TRX_BENCH_SKIP_DPV_AT", 10000000); - const size_t skip_connectome_at = parse_env_size("TRX_BENCH_SKIP_CONNECTOME_AT", 5000000); - if (use_zip && streamlines >= skip_zip_at) { - state.SkipWithMessage("zip compression skipped for large streamlines"); - return; - } - if (add_dpv && streamlines >= skip_dpv_at) { - state.SkipWithMessage("dpv skipped: requires 40-50 GB memory (set TRX_BENCH_SKIP_DPV_AT=0 to force)"); + // Skip gracefully when the reference file doesn't have enough streamlines rather than + // reading past EOF in the fast path. 
+ if (streamlines > g_reference_streamline_count) { + state.SkipWithMessage("skipped: streamlines exceeds reference file count"); return; } - if (scenario == GroupScenario::Connectome && streamlines >= skip_connectome_at) { - state.SkipWithMessage("connectome groups skipped for large streamlines (set TRX_BENCH_SKIP_CONNECTOME_AT=0 to force)"); + const size_t skip_zip_at = parse_env_size("TRX_BENCH_SKIP_ZIP_AT", 5000000); + // ZIP deflate on multi-GB files takes O(minutes) per iteration — skip for large counts. + // File sizes at these scales can be estimated from the linear trend of smaller compressed + // results (ratio is stable across counts for the same data distribution). + if (use_zip && streamlines >= skip_zip_at) { + state.SkipWithMessage("zip skipped: deflating multi-GB files is impractically slow for benchmarking"); return; } log_bench_start("BM_TrxFileSize_Float16", @@ -757,11 +943,14 @@ static void BM_TrxFileSize_Float16(benchmark::State &state) { double total_merge_ms = 0.0; double total_build_ms = 0.0; double total_merge_processes = 0.0; + double max_rss_delta_kb = 0.0; for (auto _ : state) { + const double rss_iter_start = get_current_rss_kb(); const auto start = std::chrono::steady_clock::now(); const auto on_disk = build_trx_file_on_disk(streamlines, scenario, add_dps, add_dpv, compression); const auto end = std::chrono::steady_clock::now(); + max_rss_delta_kb = std::max(max_rss_delta_kb, get_current_rss_kb() - rss_iter_start); const std::chrono::duration elapsed = end - start; total_build_ms += elapsed.count(); total_merge_ms += on_disk.shard_merge_ms; @@ -777,6 +966,7 @@ static void BM_TrxFileSize_Float16(benchmark::State &state) { state.counters["dpv"] = add_dpv ? 1.0 : 0.0; state.counters["compression"] = use_zip ? 
1.0 : 0.0; state.counters["positions_dtype"] = 16.0; + state.counters["max_rss_kb"] = max_rss_delta_kb; state.counters["write_ms"] = total_write_ms / static_cast(state.iterations()); state.counters["build_ms"] = total_build_ms / static_cast(state.iterations()); if (total_merge_ms > 0.0) { @@ -784,14 +974,58 @@ static void BM_TrxFileSize_Float16(benchmark::State &state) { state.counters["shard_processes"] = total_merge_processes / static_cast(state.iterations()); } state.counters["file_bytes"] = total_file_bytes / static_cast(state.iterations()); - state.counters["max_rss_kb"] = get_max_rss_kb(); log_bench_end("BM_TrxFileSize_Float16", "streamlines=" + std::to_string(streamlines)); } +// Pack a directory tree into a TRX zip archive using zip_source_file for every +// file so libzip reads in chunks via the OS page cache rather than mapping the +// data into the process address space. Avoids the RSS spike that occurs when +// TrxFile::save() syncs and re-reads large DPS/DPV mmaps before archiving. 
+static void pack_dir_to_zip(const std::string &src_dir, + const std::string &out_path, + zip_uint32_t compression) { + int errorp = 0; + zip_t *za = zip_open(out_path.c_str(), ZIP_CREATE | ZIP_TRUNCATE, &errorp); + if (!za) { + throw std::runtime_error("pack_dir_to_zip: could not open " + out_path); + } + std::error_code ec; + for (const auto &entry : std::filesystem::recursive_directory_iterator(src_dir, ec)) { + if (ec || entry.is_directory(ec)) { + continue; + } + const std::string abs_path = entry.path().string(); + std::string rel_path = std::filesystem::relative(entry.path(), src_dir, ec).string(); + std::replace(rel_path.begin(), rel_path.end(), '\\', '/'); + zip_source_t *src = zip_source_file(za, abs_path.c_str(), 0, -1); + if (!src) { + zip_close(za); + throw std::runtime_error("pack_dir_to_zip: zip_source_file failed for " + abs_path); + } + const zip_int64_t idx = zip_file_add(za, rel_path.c_str(), src, ZIP_FL_ENC_UTF_8 | ZIP_FL_OVERWRITE); + if (idx < 0) { + zip_source_free(src); + zip_close(za); + throw std::runtime_error("pack_dir_to_zip: zip_file_add failed for " + rel_path); + } + if (zip_set_file_compression(za, idx, static_cast(compression), 0) < 0) { + zip_close(za); + throw std::runtime_error("pack_dir_to_zip: zip_set_file_compression failed"); + } + } + if (zip_close(za) < 0) { + throw std::runtime_error("pack_dir_to_zip: zip_close failed for " + out_path); + } +} + static void BM_TrxStream_TranslateWrite(benchmark::State &state) { const size_t streamlines = static_cast(state.range(0)); + if (streamlines > g_reference_streamline_count) { + state.SkipWithMessage("skipped: streamlines exceeds reference file count"); + return; + } const auto scenario = static_cast(state.range(1)); const bool add_dps = state.range(2) != 0; const bool add_dpv = state.range(3) != 0; @@ -817,26 +1051,83 @@ static void BM_TrxStream_TranslateWrite(benchmark::State &state) { state.counters["shard_merge_ms"] = dataset.shard_merge_ms; state.counters["shard_processes"] = 
static_cast(dataset.shard_processes); } + double max_rss_delta_kb = 0.0; for (auto _ : state) { + const double rss_iter_start = get_current_rss_kb(); + + // Sample RSS on a background thread throughout the iteration so we capture + // the true peak (positions mmap + chunk tmp vector both resident during + // translation), not just the end-of-iteration value which may be lower + // after mmaps have been released. + std::atomic rss_sampling{true}; + std::atomic peak_rss_kb{static_cast(rss_iter_start)}; + std::thread rss_sampler([&]() { + while (rss_sampling.load(std::memory_order_relaxed)) { + const long s = static_cast(get_current_rss_kb()); + long prev = peak_rss_kb.load(std::memory_order_relaxed); + while (s > prev && + !peak_rss_kb.compare_exchange_weak(prev, s, + std::memory_order_relaxed, std::memory_order_relaxed)) {} + std::this_thread::sleep_for(std::chrono::milliseconds(50)); + } + }); + const auto start = std::chrono::steady_clock::now(); auto trx = trx::load_any(dataset.path); + const std::string input_dir = trx.uncompressed_folder_handle(); const size_t chunk_bytes = parse_env_size("TRX_BENCH_CHUNK_BYTES", 1024ULL * 1024ULL * 1024ULL); const std::string out_dir = make_work_dir_name("trx_translate_chunk"); - trx::PrepareOutputOptions prep_opts; - prep_opts.overwrite_existing = true; - const auto out_info = trx::prepare_positions_output(trx, out_dir, prep_opts); - std::ofstream out_positions(out_info.positions_path, std::ios::binary | std::ios::out | std::ios::trunc); - if (!out_positions.is_open()) { - throw std::runtime_error("Failed to open output positions file: " + out_info.positions_path); + // Determine the positions filename before we clear anything. + const std::string pos_fname = trx.positions.cols > 1 + ? "positions." + std::to_string(trx.positions.cols) + "." + trx.positions.dtype + : "positions." 
+ trx.positions.dtype; + const size_t total_points = static_cast(trx.positions.rows); + + // Copy all metadata and data arrays from the extracted input directory to + // out_dir using filesystem operations (fcopyfile / sendfile under the hood) + // so the data never passes through the process's mapped address space. + // Skip the positions file — we will write the translated version ourselves. + std::error_code ec; + std::filesystem::create_directories(out_dir, ec); + for (const auto &entry : std::filesystem::directory_iterator(input_dir, ec)) { + if (ec || entry.is_directory(ec)) { + continue; + } + if (entry.path().filename().string() == pos_fname) { + continue; // will be replaced with translated positions + } + std::filesystem::copy_file(entry.path(), + std::filesystem::path(out_dir) / entry.path().filename(), + std::filesystem::copy_options::overwrite_existing, ec); + } + for (const char *sub : {"dps", "dpv", "groups", "dpg"}) { + const auto src = std::filesystem::path(input_dir) / sub; + if (std::filesystem::exists(src, ec)) { + std::filesystem::copy(src, std::filesystem::path(out_dir) / sub, + std::filesystem::copy_options::recursive | + std::filesystem::copy_options::overwrite_existing, + ec); + } } + // Release DPS/DPV mmaps — data is now safely on disk in out_dir. + // Only the positions mmap is needed going forward. + trx.data_per_vertex.clear(); + trx.data_per_streamline.clear(); + + // Write translated positions. 
+ const std::string positions_path = out_dir + trx::SEPARATOR + pos_fname; + std::ofstream out_positions(positions_path, std::ios::binary | std::ios::out | std::ios::trunc); + if (!out_positions.is_open()) { + throw std::runtime_error("Failed to open output positions file: " + positions_path); + } trx.for_each_positions_chunk(chunk_bytes, [&](trx::TrxScalarType dtype, const void *data, size_t offset, size_t count) { (void)offset; if (progress_every > 0 && ((offset + count) % progress_every == 0)) { std::cerr << "[trx-bench] progress translate points=" << (offset + count) - << " / " << out_info.points << std::endl; + << " / " << total_points << std::endl; } const size_t total_vals = count * 3; if (dtype == trx::TrxScalarType::Float16) { @@ -868,20 +1159,29 @@ static void BM_TrxStream_TranslateWrite(benchmark::State &state) { out_positions.flush(); out_positions.close(); + // Release the input positions and offsets mmaps — all reading is done. + // Applies the same early-release pattern used for DPS/DPV above: once + // the mmap is cleared the OS can reclaim those pages, so pack_dir_to_zip + // does not run while they are still pinned in resident_size. + trx.positions = trx::TypedArray{}; + trx.offsets = trx::TypedArray{}; + + // Pack out_dir using zip_source_file so libzip reads files in buffered + // chunks rather than mapping them — DPS/DPV stay out of resident_size. 
const std::string out_path = make_temp_path("trx_translate"); - trx::TrxSaveOptions save_opts; - save_opts.mode = trx::TrxSaveMode::Archive; - save_opts.compression_standard = ZIP_CM_STORE; - save_opts.overwrite_existing = true; - trx.save(out_path, save_opts); + pack_dir_to_zip(out_dir, out_path, ZIP_CM_STORE); trx::rm_dir(out_dir); const auto end = std::chrono::steady_clock::now(); const std::chrono::duration elapsed = end - start; state.SetIterationTime(elapsed.count()); - std::error_code ec; std::filesystem::remove(out_path, ec); benchmark::DoNotOptimize(trx); + + rss_sampling.store(false, std::memory_order_relaxed); + rss_sampler.join(); + const double delta = static_cast(peak_rss_kb.load(std::memory_order_relaxed)) - rss_iter_start; + max_rss_delta_kb = std::max(max_rss_delta_kb, delta); } state.counters["streamlines"] = static_cast(streamlines); @@ -890,7 +1190,7 @@ static void BM_TrxStream_TranslateWrite(benchmark::State &state) { state.counters["dps"] = add_dps ? 1.0 : 0.0; state.counters["dpv"] = add_dpv ? 1.0 : 0.0; state.counters["positions_dtype"] = 16.0; - state.counters["max_rss_kb"] = get_max_rss_kb(); + state.counters["max_rss_kb"] = max_rss_delta_kb; log_bench_end("BM_TrxStream_TranslateWrite", "streamlines=" + std::to_string(streamlines) + " group_case=" + std::to_string(state.range(1))); @@ -898,13 +1198,20 @@ static void BM_TrxStream_TranslateWrite(benchmark::State &state) { static void BM_TrxQueryAabb_Slabs(benchmark::State &state) { const size_t streamlines = static_cast(state.range(0)); + if (streamlines > g_reference_streamline_count) { + state.SkipWithMessage("skipped: streamlines exceeds reference file count"); + return; + } const auto scenario = static_cast(state.range(1)); const bool add_dps = state.range(2) != 0; const bool add_dpv = state.range(3) != 0; + // 0 means "no limit" (return all intersecting streamlines). 
+ const size_t max_query_streamlines = parse_env_size("TRX_BENCH_MAX_QUERY_STREAMLINES", 500); log_bench_start("BM_TrxQueryAabb_Slabs", "streamlines=" + std::to_string(streamlines) + " group_case=" + std::to_string(state.range(1)) + " dps=" + std::to_string(static_cast(add_dps)) + - " dpv=" + std::to_string(static_cast(add_dpv))); + " dpv=" + std::to_string(static_cast(add_dpv)) + + " max_query_streamlines=" + std::to_string(max_query_streamlines)); using Key = KeyHash::Key; static std::unordered_map cache; @@ -922,7 +1229,9 @@ static void BM_TrxQueryAabb_Slabs(benchmark::State &state) { } auto &dataset = cache.at(key); + double max_rss_delta_kb = 0.0; for (auto _ : state) { + const double rss_iter_start = get_current_rss_kb(); std::vector slab_times_ms; slab_times_ms.reserve(kSlabCount); @@ -932,7 +1241,11 @@ static void BM_TrxQueryAabb_Slabs(benchmark::State &state) { const auto &min_corner = dataset.slab_mins[i]; const auto &max_corner = dataset.slab_maxs[i]; const auto q_start = std::chrono::steady_clock::now(); - auto subset = dataset.trx->query_aabb(min_corner, max_corner); + auto subset = dataset.trx->query_aabb(min_corner, max_corner, + /*precomputed_aabbs=*/nullptr, + /*build_cache_for_result=*/false, + max_query_streamlines, + /*rng_seed=*/static_cast(i)); const auto q_end = std::chrono::steady_clock::now(); const std::chrono::duration q_elapsed = q_end - q_start; slab_times_ms.push_back(q_elapsed.count()); @@ -952,6 +1265,8 @@ static void BM_TrxQueryAabb_Slabs(benchmark::State &state) { state.counters["query_p50_ms"] = p50; state.counters["query_p95_ms"] = p95; + max_rss_delta_kb = std::max(max_rss_delta_kb, get_current_rss_kb() - rss_iter_start); + ScenarioParams params; params.streamlines = streamlines; params.scenario = scenario; @@ -966,9 +1281,10 @@ static void BM_TrxQueryAabb_Slabs(benchmark::State &state) { state.counters["dps"] = add_dps ? 1.0 : 0.0; state.counters["dpv"] = add_dpv ? 
1.0 : 0.0; state.counters["query_count"] = static_cast(kSlabCount); + state.counters["max_query_streamlines"] = static_cast(max_query_streamlines); state.counters["slab_thickness_mm"] = kSlabThicknessMm; state.counters["positions_dtype"] = 16.0; - state.counters["max_rss_kb"] = get_max_rss_kb(); + state.counters["max_rss_kb"] = max_rss_delta_kb; log_bench_end("BM_TrxQueryAabb_Slabs", "streamlines=" + std::to_string(streamlines) + " group_case=" + std::to_string(state.range(1))); @@ -1002,17 +1318,12 @@ static void ApplySizeArgs(benchmark::internal::Benchmark *bench) { static void ApplyStreamArgs(benchmark::internal::Benchmark *bench) { const std::array flags = {0, 1}; - const bool core_profile = is_core_profile(); - const size_t dpv_limit = core_dpv_max_streamlines(); const auto groups = group_cases_for_benchmarks(); const auto counts_desc = streamlines_for_benchmarks(); for (const auto count : counts_desc) { - const std::vector dpv_flags = (!core_profile || count <= dpv_limit) - ? std::vector{0, 1} - : std::vector{0}; for (const auto group_case : groups) { for (const auto dps : flags) { - for (const auto dpv : dpv_flags) { + for (const auto dpv : flags) { bench->Args({static_cast(count), group_case, dps, dpv}); } } @@ -1104,6 +1415,16 @@ int main(int argc, char **argv) { // Set global reference path g_reference_trx_path = reference_trx; std::cerr << "[trx-bench] Using reference TRX: " << g_reference_trx_path << std::endl; + + // Pre-inspect reference to avoid repeated loads and to inform skip decisions + { + auto ref = trx::load(g_reference_trx_path); + g_reference_streamline_count = ref->num_streamlines(); + g_reference_has_dpv = !ref->data_per_vertex.empty(); + ref->close(); + std::cerr << "[trx-bench] Reference: " << g_reference_streamline_count + << " streamlines, dpv=" << (g_reference_has_dpv ? 
"yes" : "no") << std::endl; + } // Enable verbose logging if requested if (verbose) { diff --git a/bench/plot_bench.R b/bench/plot_bench.R index 09ade27..53c6a6c 100755 --- a/bench/plot_bench.R +++ b/bench/plot_bench.R @@ -69,19 +69,19 @@ parse_args <- function() { list(bench_dir = bench_dir, out_dir = out_dir) } -#' Convert benchmark time to milliseconds -time_to_ms <- function(bench) { +#' Convert benchmark time to seconds +time_to_s <- function(bench) { value <- bench$real_time unit <- bench$time_unit - + multiplier <- switch(unit, - "ns" = 1e-6, - "us" = 1e-3, - "ms" = 1, - "s" = 1e3, - 1e-6 # default to nanoseconds + "ns" = 1e-9, + "us" = 1e-6, + "ms" = 1e-3, + "s" = 1, + 1e-9 # default to nanoseconds ) - + value * multiplier } @@ -131,7 +131,7 @@ load_benchmarks <- function(bench_dir) { row <- list( name = name, base = parse_base_name(name), - real_time_ms = time_to_ms(bench), + real_time_s = time_to_s(bench), streamlines = bench$streamlines %||% NA, length_profile = bench$length_profile %||% NA, compression = bench$compression %||% NA, @@ -170,17 +170,71 @@ load_benchmarks <- function(bench_dir) { df } +#' Estimate file sizes for missing streamline counts via linear extrapolation. +#' Uses a per-(dps, dpv, compression, group_case) linear model fit to measured +#' data to fill in counts not present in the benchmark results (e.g. 10M). 
+estimate_missing_file_sizes <- function(sub_df) { + target_sl <- sort(unique(c(sub_df$streamlines, 10000000))) + + combos <- sub_df %>% + filter(!is.na(file_bytes)) %>% + group_by(dps, dpv, compression, group_case) %>% + filter(n() >= 2) %>% + summarise(.groups = "drop") + + if (nrow(combos) == 0) return(NULL) + + estimated_rows <- list() + + for (i in seq_len(nrow(combos))) { + d <- combos$dps[i] + v <- combos$dpv[i] + comp <- combos$compression[i] + gc <- combos$group_case[i] + + existing <- sub_df %>% + filter(dps == d, dpv == v, compression == comp, + group_case == gc, !is.na(file_bytes)) + + missing_sl <- setdiff(target_sl, existing$streamlines) + if (length(missing_sl) == 0) next + + fit <- lm(file_bytes ~ streamlines, data = existing) + pred <- predict(fit, newdata = data.frame(streamlines = as.numeric(missing_sl))) + + template <- existing[1, , drop = FALSE] + for (j in seq_along(missing_sl)) { + row <- template + row$streamlines <- missing_sl[j] + row$file_bytes <- max(0, pred[j]) + row$estimated <- TRUE + estimated_rows[[length(estimated_rows) + 1]] <- row + } + } + + if (length(estimated_rows) == 0) return(NULL) + bind_rows(estimated_rows) +} + #' Plot file sizes plot_file_sizes <- function(df, out_dir) { sub_df <- df %>% filter(base == "BM_TrxFileSize_Float16") %>% - filter(!is.na(file_bytes), !is.na(streamlines)) - + filter(!is.na(file_bytes), !is.na(streamlines)) %>% + mutate(estimated = FALSE) + if (nrow(sub_df) == 0) { cat("No BM_TrxFileSize_Float16 results found, skipping file size plot\n") return(invisible(NULL)) } - + + # Extrapolate to missing streamline counts (e.g. 
10M not present in benchmark data) + est_df <- estimate_missing_file_sizes(sub_df) + if (!is.null(est_df)) { + cat("Added", nrow(est_df), "estimated file size entries (linear extrapolation)\n") + sub_df <- bind_rows(sub_df, est_df) + } + sub_df <- sub_df %>% mutate( file_mb = file_bytes / 1e6, @@ -188,34 +242,49 @@ plot_file_sizes <- function(df, out_dir) { group_label = recode( as.character(ifelse(is.na(group_case), "0", as.character(group_case))), !!!GROUP_LABELS), - dp_label = sprintf("dpv=%d, dps=%d", as.integer(dpv), as.integer(dps)) + dp_label = sprintf("dpv=%d, dps=%d", as.integer(dpv), as.integer(dps)), + measured = ifelse(estimated, "estimated", "measured"), + streamlines_f = factor( + streamlines, + levels = sort(as.numeric(unique(streamlines))), + labels = label_number(scale = 1e-6, suffix = "M")(sort(as.numeric(unique(streamlines)))) + ) ) n_group_levels <- length(unique(sub_df$group_label)) - plot_height <- if (n_group_levels > 1) 5 + 3 * n_group_levels else 7 - - p <- ggplot(sub_df, aes(x = streamlines, y = file_mb, - color = dp_label)) + - geom_line(linewidth = 0.8) + - geom_point(size = 2) + - facet_grid(group_label ~ compression_label) + - scale_x_continuous(labels = label_number(scale = 1e-6, suffix = "M")) + + plot_height <- if (n_group_levels > 1) 4 + 2 * n_group_levels else 6 + + p <- ggplot(sub_df, aes(x = streamlines_f, y = file_mb, + fill = dp_label, + linetype = compression_label, + alpha = measured)) + + geom_hline(yintercept = 18500, color = "firebrick", linetype = "dashed", linewidth = 0.8) + + annotate("text", x = Inf, y = 18500, label = "TCK reference 10M (18.5 GB)", + hjust = 1.05, vjust = -0.4, color = "firebrick", size = 3) + + geom_col(position = "dodge", color = "grey30", linewidth = 0.5) + + facet_wrap(~group_label, ncol = 1) + scale_y_continuous(labels = label_number()) + + scale_alpha_manual( + values = c("measured" = 0.9, "estimated" = 0.45), + name = "" + ) + labs( title = "TRX file size vs streamlines (float16 positions)", x 
= "Streamlines", y = "File size (MB)", - color = "Data per streamline/vertex" + fill = "Data per streamline/vertex", + linetype = "Compression" ) + theme_bw() + theme( legend.position = "bottom", legend.box = "vertical", - strip.background = element_rect(fill = "grey90") + strip.background = element_rect(fill = "grey90"), + axis.text.x = element_text(angle = 45, hjust = 1) ) out_path <- file.path(out_dir, "trx_size_vs_streamlines.png") - ggsave(out_path, p, width = 12, height = plot_height, dpi = 160) + ggsave(out_path, p, width = 10, height = plot_height, dpi = 160) cat("Saved:", out_path, "\n") } @@ -223,64 +292,65 @@ plot_file_sizes <- function(df, out_dir) { plot_translate_write <- function(df, out_dir) { sub_df <- df %>% filter(base == "BM_TrxStream_TranslateWrite") %>% - filter(!is.na(real_time_ms), !is.na(streamlines)) - + filter(!is.na(real_time_s), !is.na(streamlines)) + if (nrow(sub_df) == 0) { cat("No BM_TrxStream_TranslateWrite results found, skipping translate plots\n") return(invisible(NULL)) } - + sub_df <- sub_df %>% mutate( group_label = recode(as.character(group_case), !!!GROUP_LABELS), dp_label = sprintf("dpv=%d, dps=%d", as.integer(dpv), as.integer(dps)), - rss_mb = max_rss_kb / 1024 + rss_gb = max_rss_kb / (1024 * 1024), + streamlines_f = factor( + streamlines, + levels = sort(as.numeric(unique(streamlines))), + labels = label_number(scale = 1e-6, suffix = "M")(sort(as.numeric(unique(streamlines)))) + ) ) - + # Time plot - p_time <- ggplot(sub_df, aes(x = streamlines, y = real_time_ms, - color = dp_label)) + - geom_line(linewidth = 0.8) + - geom_point(size = 2) + + p_time <- ggplot(sub_df, aes(x = streamlines_f, y = real_time_s, fill = dp_label)) + + geom_col(position = "dodge") + facet_wrap(~group_label, ncol = 2) + - scale_x_continuous(labels = label_number(scale = 1e-6, suffix = "M")) + scale_y_continuous(labels = label_number()) + labs( title = "Translate + stream write throughput", x = "Streamlines", - y = "Time (ms)", - color = "Data per 
point" + y = "Time (s)", + fill = "Data per point" ) + theme_bw() + theme( legend.position = "bottom", - strip.background = element_rect(fill = "grey90") + strip.background = element_rect(fill = "grey90"), + axis.text.x = element_text(angle = 45, hjust = 1) ) - + out_path <- file.path(out_dir, "trx_translate_write_time.png") ggsave(out_path, p_time, width = 12, height = 5, dpi = 160) cat("Saved:", out_path, "\n") - + # RSS plot - p_rss <- ggplot(sub_df, aes(x = streamlines, y = rss_mb, - color = dp_label)) + - geom_line(linewidth = 0.8) + - geom_point(size = 2) + + p_rss <- ggplot(sub_df, aes(x = streamlines_f, y = rss_gb, fill = dp_label)) + + geom_col(position = "dodge") + facet_wrap(~group_label, ncol = 2) + - scale_x_continuous(labels = label_number(scale = 1e-6, suffix = "M")) + scale_y_continuous(labels = label_number()) + labs( title = "Translate + stream write memory usage", x = "Streamlines", - y = "Max RSS (MB)", - color = "Data per point" + y = "Max RSS (GB)", + fill = "Data per point" ) + theme_bw() + theme( legend.position = "bottom", - strip.background = element_rect(fill = "grey90") + strip.background = element_rect(fill = "grey90"), + axis.text.x = element_text(angle = 45, hjust = 1) ) - + out_path <- file.path(out_dir, "trx_translate_write_rss.png") ggsave(out_path, p_rss, width = 12, height = 5, dpi = 160) cat("Saved:", out_path, "\n") @@ -309,7 +379,7 @@ load_query_timings <- function(jsonl_path) { dps = obj$dps %||% NA, dpv = obj$dpv %||% NA, slab_thickness_mm = obj$slab_thickness_mm %||% NA, - timings_ms = I(list(unlist(obj$timings_ms))) + timings_s = I(list(unlist(obj$timings_ms) / 1000)) ) }, error = function(e) NULL) }) @@ -350,55 +420,49 @@ load_all_query_timings <- function(bench_dir) { } #' Plot query timing distributions -plot_query_timings <- function(bench_dir, out_dir, group_case = 0, dpv = 0, dps = 0) { +plot_query_timings <- function(bench_dir, out_dir) { df <- load_all_query_timings(bench_dir) - + if (is.null(df) || nrow(df) == 0) 
{ cat("No query_timings*.jsonl found or empty, skipping query timing plot\n") return(invisible(NULL)) } - - # Filter by specified conditions - df_filtered <- df %>% - filter( - group_case == !!group_case, - dpv == !!dpv, - dps == !!dps - ) - - if (nrow(df_filtered) == 0) { - cat("No query timings matching filters (group_case=", group_case, - ", dpv=", dpv, ", dps=", dps, "), skipping plot\n", sep = "") - return(invisible(NULL)) - } - - # Expand timings into long format - timing_data <- df_filtered %>% - mutate(streamlines_label = format(streamlines, big.mark = ",")) %>% - select(streamlines, streamlines_label, timings_ms) %>% - unnest(timings_ms) %>% - group_by(streamlines, streamlines_label) %>% - mutate(query_id = row_number()) %>% + + # Expand timings into long format, keeping all group/dpv/dps combinations + timing_data <- df %>% + mutate( + streamlines_label = factor( + label_number(scale = 1e-6, suffix = "M")(streamlines), + levels = label_number(scale = 1e-6, suffix = "M")(sort(as.numeric(unique(streamlines)))) + ), + dp_label = sprintf("dpv=%d, dps=%d", as.integer(dpv), as.integer(dps)), + group_label = recode(as.character(group_case), !!!GROUP_LABELS) + ) %>% + select(streamlines, streamlines_label, dp_label, group_label, timings_s) %>% + unnest(timings_s) %>% ungroup() - - # Create boxplot - group_label <- GROUP_LABELS[as.character(group_case)] - - p <- ggplot(timing_data, aes(x = streamlines_label, y = timings_ms)) + + + n_groups <- length(unique(timing_data$group_label)) + n_dp <- length(unique(timing_data$dp_label)) + plot_height <- max(6, 3 * n_groups) + plot_width <- max(10, 4 * n_dp) + + p <- ggplot(timing_data, aes(x = streamlines_label, y = timings_s)) + geom_boxplot(fill = "steelblue", alpha = 0.7, outlier.size = 0.5) + + facet_grid(group_label ~ dp_label) + labs( - title = sprintf("Slab query timings (%s, dpv=%d, dps=%d)", - group_label, dpv, dps), + title = "Slab query timings by group and data profile", x = "Streamlines", - y = "Per-slab query 
time (ms)" + y = "Per-slab query time (s)" ) + theme_bw() + theme( - axis.text.x = element_text(angle = 45, hjust = 1) + axis.text.x = element_text(angle = 45, hjust = 1), + strip.background = element_rect(fill = "grey90") ) - + out_path <- file.path(out_dir, "trx_query_slab_timings.png") - ggsave(out_path, p, width = 10, height = 6, dpi = 160) + ggsave(out_path, p, width = plot_width, height = plot_height, dpi = 160) cat("Saved:", out_path, "\n") } @@ -421,7 +485,7 @@ main <- function() { # Generate plots plot_file_sizes(df, args$out_dir) plot_translate_write(df, args$out_dir) - plot_query_timings(args$bench_dir, args$out_dir, group_case = 0, dpv = 0, dps = 0) + plot_query_timings(args$bench_dir, args$out_dir) cat("\nDone! Plots saved to:", args$out_dir, "\n") } diff --git a/bench/run_benchmarks.sh b/bench/run_benchmarks.sh index 52272d3..e8d8b70 100755 --- a/bench/run_benchmarks.sh +++ b/bench/run_benchmarks.sh @@ -140,7 +140,7 @@ run_benchmark() { # Build profile environment defaults (users can override by exporting env vars) if [[ "$PROFILE" == "core" ]]; then - CORE_ENV="TRX_BENCH_PROFILE=${TRX_BENCH_PROFILE:-core} TRX_BENCH_MAX_STREAMLINES=${TRX_BENCH_MAX_STREAMLINES:-5000000} TRX_BENCH_SKIP_DPV_AT=${TRX_BENCH_SKIP_DPV_AT:-1000000} TRX_BENCH_QUERY_CACHE_MAX=${TRX_BENCH_QUERY_CACHE_MAX:-5} TRX_BENCH_CORE_INCLUDE_BUNDLES=${TRX_BENCH_CORE_INCLUDE_BUNDLES:-0} TRX_BENCH_INCLUDE_CONNECTOME=${TRX_BENCH_INCLUDE_CONNECTOME:-0} TRX_BENCH_CORE_DPV_MAX_STREAMLINES=${TRX_BENCH_CORE_DPV_MAX_STREAMLINES:-1000000} TRX_BENCH_CORE_ZIP_MAX_STREAMLINES=${TRX_BENCH_CORE_ZIP_MAX_STREAMLINES:-1000000}" + CORE_ENV="TRX_BENCH_PROFILE=${TRX_BENCH_PROFILE:-core} TRX_BENCH_MAX_STREAMLINES=${TRX_BENCH_MAX_STREAMLINES:-10000000} TRX_BENCH_QUERY_CACHE_MAX=${TRX_BENCH_QUERY_CACHE_MAX:-5} TRX_BENCH_CORE_INCLUDE_BUNDLES=${TRX_BENCH_CORE_INCLUDE_BUNDLES:-0} TRX_BENCH_INCLUDE_CONNECTOME=${TRX_BENCH_INCLUDE_CONNECTOME:-0} 
TRX_BENCH_CORE_ZIP_MAX_STREAMLINES=${TRX_BENCH_CORE_ZIP_MAX_STREAMLINES:-1000000}" else CORE_ENV="TRX_BENCH_PROFILE=${TRX_BENCH_PROFILE:-full} TRX_BENCH_MAX_STREAMLINES=${TRX_BENCH_MAX_STREAMLINES:-10000000} TRX_BENCH_SKIP_DPV_AT=${TRX_BENCH_SKIP_DPV_AT:-10000000} TRX_BENCH_QUERY_CACHE_MAX=${TRX_BENCH_QUERY_CACHE_MAX:-10} TRX_BENCH_INCLUDE_CONNECTOME=${TRX_BENCH_INCLUDE_CONNECTOME:-1}" fi diff --git a/docs/_static/benchmarks/trx_query_slab_timings.png b/docs/_static/benchmarks/trx_query_slab_timings.png index b950715..8c3f872 100644 Binary files a/docs/_static/benchmarks/trx_query_slab_timings.png and b/docs/_static/benchmarks/trx_query_slab_timings.png differ diff --git a/docs/_static/benchmarks/trx_size_vs_streamlines.png b/docs/_static/benchmarks/trx_size_vs_streamlines.png index f43a93f..e545ebb 100644 Binary files a/docs/_static/benchmarks/trx_size_vs_streamlines.png and b/docs/_static/benchmarks/trx_size_vs_streamlines.png differ diff --git a/docs/_static/benchmarks/trx_translate_write_rss.png b/docs/_static/benchmarks/trx_translate_write_rss.png index 1a4b8bc..d30c44f 100644 Binary files a/docs/_static/benchmarks/trx_translate_write_rss.png and b/docs/_static/benchmarks/trx_translate_write_rss.png differ diff --git a/docs/_static/benchmarks/trx_translate_write_time.png b/docs/_static/benchmarks/trx_translate_write_time.png index cfed582..1a0369c 100644 Binary files a/docs/_static/benchmarks/trx_translate_write_time.png and b/docs/_static/benchmarks/trx_translate_write_time.png differ diff --git a/docs/api_layers.rst b/docs/api_layers.rst new file mode 100644 index 0000000..20b5366 --- /dev/null +++ b/docs/api_layers.rst @@ -0,0 +1,82 @@ +API Layers +========== + +trx-cpp provides three complementary interfaces. Understanding when to use +each one avoids boilerplate and makes code easier to reason about. + +AnyTrxFile — runtime-typed +-------------------------- + +:class:`trx::AnyTrxFile` is the simplest entry point. 
It reads the positions +dtype directly from the file and exposes all arrays as :class:`trx::TypedArray` +objects with a ``dtype`` string field. Use this when you only have a file path +and do not need to perform arithmetic on positions at the C++ type level. + +.. code-block:: cpp + + auto trx = trx::load_any("tracks.trx"); + + std::cout << trx.positions.dtype << "\n"; // e.g. "float16" + std::cout << trx.num_streamlines() << "\n"; + + // Access positions as any floating-point type you choose. + auto pos = trx.positions.as_matrix(); // (NB_VERTICES, 3) + + trx.close(); + +TrxFile
— compile-time typed +--------------------------------- + +:class:`trx::TrxFile` is templated on the positions dtype (``Eigen::half``, +``float``, or ``double``). Positions and DPV arrays are exposed as +``Eigen::Matrix`` directly — no element-wise conversion. Use this +when the dtype is known, or when you are performing per-vertex arithmetic and +want the compiler to enforce type consistency. + +.. code-block:: cpp + + auto reader = trx::load("tracks.trx"); + auto& trx = *reader; + + // trx.streamlines->_data is Eigen::Matrix + std::cout << trx.streamlines->_data.rows() << " vertices\n"; + + reader->close(); + +The recommended typed entry point is :func:`trx::with_trx_reader`, which +detects the dtype at runtime and dispatches to the correct instantiation: + +.. code-block:: cpp + + trx::with_trx_reader("tracks.trx", [](auto& trx) { + // trx is TrxFile
for the detected dtype + std::cout << trx.num_vertices() << "\n"; + }); + +TrxReader
— RAII lifetime management +---------------------------------------- + +:class:`trx::TrxReader` is a thin RAII wrapper around :class:`trx::TrxFile`. +When a TRX zip is loaded, trx-cpp extracts it to a temporary directory. +``TrxReader`` owns that directory and removes it when it goes out of scope, +ensuring no temporary files are leaked. + +In most cases you do not need to instantiate ``TrxReader`` directly. The +convenience functions ``trx::load_any`` and ``trx::with_trx_reader`` handle +the lifetime automatically. ``TrxReader`` is available for advanced use cases +where explicit control over the backing resource lifetime is required — for +example, when passing a ``TrxFile`` across a function boundary and needing the +temporary directory to outlive the calling scope. + +Summary +------- + ++--------------------+----------------------------------+-----------------------------+ +| Class | Dtype | Best for | ++====================+==================================+=============================+ +| ``AnyTrxFile`` | Runtime (read from file) | Inspection, generic tools | ++--------------------+----------------------------------+-----------------------------+ +| ``TrxFile
`` | Compile-time | Per-vertex computation | ++--------------------+----------------------------------+-----------------------------+ +| ``TrxReader
`` | Compile-time + RAII cleanup | Explicit lifetime control | ++--------------------+----------------------------------+-----------------------------+ diff --git a/docs/benchmarks.rst b/docs/benchmarks.rst index e10ddfa..c29cf61 100644 --- a/docs/benchmarks.rst +++ b/docs/benchmarks.rst @@ -5,26 +5,20 @@ This page documents the benchmarking suite and how to interpret the results. The benchmarks are designed for realistic tractography workloads (HPC scale), not for CI. They focus on file size, throughput, and interactive spatial queries. -Data model ----------- +Data +---- -All benchmarks synthesize smooth, slightly curved streamlines in a realistic -field of view: - -- **Lengths:** random between 20 and 500 mm (profiles skew short/medium/long) -- **Field of view:** x = [-70, 70], y = [-108, 79], z = [-60, 75] (mm, RAS+) -- **Streamline counts:** 100k, 500k, 1M, 5M, 10M -- **Groups:** none, 80 bundle groups, or 4950 connectome groups (100 regions) -- **DPV/DPS:** either present (1 value) or absent - -Positions are stored as float16 to highlight storage efficiency. +All benchmarks use a real dataset from the Human Connectome Project (Young Adult dataset). +The original size of the `tck` file is 18.5GB. +The streamlines were generated by `tckgen` and the data per vertex were generated by `tcksift2`. +This data was converted to TRX format (float16 positions) with nibabel. TRX size vs streamline count ---------------------------- -This benchmark writes TRX files with float16 positions and measures the final -on-disk size for different streamline counts. It compares short/medium/long -length profiles, DPV/DPS presence, and zip compression (store vs deflate). +This benchmark creates subsets of the reference TRX file and measures the +on-disk size for different streamline counts. This also calculates the size +with and without additional features such as DPS, DPV and groups. .. 
figure:: _static/benchmarks/trx_size_vs_streamlines.png :alt: TRX file size vs streamlines @@ -35,10 +29,26 @@ length profiles, DPV/DPS presence, and zip compression (store vs deflate). Translate + stream write throughput ----------------------------------- -This benchmark loads a TRX file, iterates through every streamline, translates -each point by +1 mm in x/y/z, and streams the result into a new TRX file. It -reports total wall time and max RSS so researchers can understand throughput -and memory pressure on both clusters and laptops. +This benchmark simulates loading a TRX file, applying a spatial transform, and saving it +to a new TRX file (preserving DPV, DPS, groups, etc.). The workflow is: + +1. Load the input TRX file (decompresses to a temp directory). +2. Copy all metadata files (header, offsets) and data array directories (``dps``, ``dpv``, + ``groups``) to a fresh output directory using kernel-level filesystem copies + (``fcopyfile`` on macOS, ``sendfile`` on Linux). The data never passes through the + process's mapped address space. +3. Release DPS/DPV mmaps immediately after the copy so their pages are not counted + against the benchmark's RSS. +4. Iterate through every position point, apply a +1 mm translation in x/y/z, and + stream the translated positions directly to the output directory. +5. Pack the output directory into a ``.trx`` zip archive using ``zip_source_file`` + (libzip buffered I/O), which reads files in chunks without mapping them into + the process address space. + +The reported ``max_rss_kb`` is the **per-iteration RSS delta** — the peak increase in +resident memory observed during that single iteration, measured with +``mach_task_basic_info`` (macOS) or ``/proc/self/status`` ``VmRSS`` (Linux). This avoids +contamination from previous benchmark iterations whose pages may still be resident. .. 
figure:: _static/benchmarks/trx_translate_write_time.png :alt: Translate + stream write time @@ -50,15 +60,25 @@ and memory pressure on both clusters and laptops. :alt: Translate + stream write RSS :align: center - Max RSS during translate + stream write. + Peak RSS delta during translate + stream write. Because DPS/DPV are copied at the + filesystem level and their mmaps are released before position processing begins, only + the positions chunk buffer (configurable via ``TRX_BENCH_CHUNK_BYTES``, default 1 GB) + and the output stream contribute to the measured RSS. Spatial slab query latency -------------------------- -This benchmark precomputes per-streamline AABBs and then issues 100 spatial -queries using 5 mm slabs that sweep through the tractogram volume. Each slab -query mimics a GUI slice update and records its timing so distributions can be -visualized. +Streamlines need to be visualized. +One common method for this is to break the tractogram into spatial slabs and query each slab individually. +Both `mrview` and DSI Studio have methods for showing sliced streamlines in 2D. + +The `trx-cpp` library provides a few methods for subsetting a large set of streamlines. +This benchmark uses the Axis-aligned bounding box (AABB) method to create subsets of the tractogram. + +After computing the AABBs, it issues 20 spatial queries using 5 mm slabs that sweep through the tractogram volume. +Each slab query mimics a GUI slice update and records its timing so distributions can be visualized. +To keep results representative of interactive use, each query returns at most 500 randomly sampled +streamlines from all those intersecting the slab (configurable via ``TRX_BENCH_MAX_QUERY_STREAMLINES``). .. figure:: _static/benchmarks/trx_query_slab_timings.png :alt: Slab query timings @@ -66,126 +86,84 @@ visualized. Distribution of per-slab query latency. 
-Performance characteristics ---------------------------- - -Benchmark results vary significantly based on storage performance: - -**SSD (solid-state drives):** -- **CPU-bound**: Disk writes complete faster than streamline generation -- High CPU utilization (~100%) -- Results reflect pure computational throughput - -**HDD (spinning disks):** -- **I/O-bound**: Disk writes are the bottleneck -- Low CPU utilization (~5-10%) -- Results reflect realistic workstation performance with storage latency - -Both scenarios are valuable. SSD results show the library's maximum throughput, -while HDD results show real-world performance on cost-effective storage. On -Linux, monitor I/O wait time with ``iostat -x 1`` to identify the bottleneck. - -For spinning disks or network filesystems, you may want to increase buffer sizes -to amortize I/O latency. Set ``TRX_BENCH_BUFFER_MULTIPLIER`` to use larger -buffers (e.g., ``TRX_BENCH_BUFFER_MULTIPLIER=4`` uses 4× the default buffer -sizes). Running the benchmarks ---------------------- -Build and run the benchmarks, then plot results with matplotlib: +Build into ``build-release`` (the default expected by ``run_benchmarks.sh``), then use +the helper script to run all groups and generate plots: .. code-block:: bash - cmake -S . -B build -DTRX_BUILD_BENCHMARKS=ON - cmake --build build --target bench_trx_stream + cmake -S . -B build-release \ + -DCMAKE_BUILD_TYPE=Release \ + -DTRX_BUILD_BENCHMARKS=ON + cmake --build build-release --target bench_trx_realdata - # Run benchmarks (this can be long for large datasets). - ./build/bench/bench_trx_stream \ - --benchmark_out=bench/results.json \ - --benchmark_out_format=json - - # For slower storage (HDD, NFS), use larger buffers: - TRX_BENCH_BUFFER_MULTIPLIER=4 \ - ./build/bench/bench_trx_stream \ - --benchmark_out=bench/results_hdd.json \ - --benchmark_out_format=json - - # Capture per-slab timings for query distributions. 
- TRX_QUERY_TIMINGS_PATH=bench/query_timings.jsonl \ - ./build/bench/bench_trx_stream \ - --benchmark_filter=BM_TrxQueryAabb_Slabs \ - --benchmark_out=bench/results.json \ - --benchmark_out_format=json - - # Optional: record RSS samples for file-size runs. - TRX_RSS_SAMPLES_PATH=bench/rss_samples.jsonl \ - TRX_RSS_SAMPLE_EVERY=50000 \ - TRX_RSS_SAMPLE_MS=500 \ - ./build/bench/bench_trx_stream \ - --benchmark_filter=BM_TrxFileSize_Float16 \ - --benchmark_out=bench/results.json \ - --benchmark_out_format=json + # Run all benchmark groups (filesize, translate+write, query) in sequence. + ./bench/run_benchmarks.sh --reference /path/to/reference.trx # Generate plots into docs/_static/benchmarks. - python bench/plot_bench.py bench/results.json \ - --query-json bench/query_timings.jsonl \ - --out-dir docs/_static/benchmarks - -The query plot defaults to the "no groups, no DPV/DPS" case. Use -``--group-case``, ``--dpv``, and ``--dps`` in ``plot_bench.py`` to select other -scenarios. - -Environment variables ---------------------- - -The benchmark suite supports several environment variables for customization: - -**Multiprocessing:** - -- ``TRX_BENCH_PROCESSES`` (default: 1): Number of processes for parallel shard - generation. Recommended: number of physical cores. -- ``TRX_BENCH_MP_MIN_STREAMLINES`` (default: 1000000): Minimum streamline count - to enable multiprocessing. Below this threshold, single-process mode is used. -- ``TRX_BENCH_KEEP_SHARDS`` (default: 0): Set to 1 to preserve shard directories - after merging for debugging. -- ``TRX_BENCH_SHARD_WAIT_MS`` (default: 10000): Timeout in milliseconds for - waiting for shard completion markers. - -**Buffering (for slow storage):** - -- ``TRX_BENCH_BUFFER_MULTIPLIER`` (default: 1): Scales position and metadata - buffer sizes. Use larger values (2-8) for spinning disks or network - filesystems to reduce I/O latency. 
Example: multiplier=4 uses 64 MB → 256 MB - for small datasets, 256 MB → 1 GB for 1M streamlines, 2 GB → 8 GB for 5M+ - streamlines. - -**Performance tuning:** - -- ``TRX_BENCH_THREADS`` (default: hardware_concurrency): Worker threads for - streamline generation within each process. -- ``TRX_BENCH_BATCH`` (default: 1000): Streamlines per batch in the producer- - consumer queue. -- ``TRX_BENCH_QUEUE_MAX`` (default: 8): Maximum batches in flight between - producers and consumer. - -**Dataset control:** - -- ``TRX_BENCH_ONLY_STREAMLINES`` (default: 0): If nonzero, benchmark only this - streamline count instead of the full range. -- ``TRX_BENCH_MAX_STREAMLINES`` (default: 10000000): Maximum streamline count - to benchmark. Use smaller values for faster iteration. -- ``TRX_BENCH_SKIP_ZIP_AT`` (default: 5000000): Skip zip compression for - streamline counts at or above this threshold. - -**Logging and diagnostics:** - -- ``TRX_BENCH_LOG`` (default: 0): Enable benchmark progress logging to stderr. -- ``TRX_BENCH_CHILD_LOG`` (default: 0): Enable logging from child processes in - multiprocess mode. -- ``TRX_BENCH_LOG_PROGRESS_EVERY`` (default: 0): Log progress every N - streamlines. - -When running with multiprocessing, the benchmark uses -``finalize_directory_persistent()`` to write shard outputs without removing -pre-created directories, avoiding race conditions in the parallel workflow. + Rscript bench/plot_bench.R --bench-dir bench --out-dir docs/_static/benchmarks + +``run_benchmarks.sh`` accepts the following flags: + +.. list-table:: + :widths: 30 70 + :header-rows: 1 + + * - Flag + - Description + * - ``--reference PATH`` + - Path to the reference ``.trx`` file (required). + * - ``--out-dir DIR`` + - Output directory for JSON results (default: ``bench/``). + * - ``--build-dir DIR`` + - Path to the CMake build directory (default: ``build-release``). 
+ * - ``--profile core|full`` + - ``core`` (default) skips DPV and compression for large streamline counts to keep + run time manageable; ``full`` runs every combination. + * - ``--verbose`` + - Enable per-benchmark progress logging. + +Environment variables (set before calling ``run_benchmarks.sh`` or the binary directly): + +.. list-table:: + :widths: 35 65 + :header-rows: 1 + + * - Variable + - Description + * - ``TRX_BENCH_BUFFER_MULTIPLIER`` + - Scale I/O buffer sizes for slow storage (HDD, NFS). Default: ``1``. + * - ``TRX_BENCH_CHUNK_BYTES`` + - Positions chunk size for translate+write (default: ``1073741824`` = 1 GB). + * - ``TRX_BENCH_MAX_STREAMLINES`` + - Cap the maximum streamline count tested (default: ``10000000``). + * - ``TRX_QUERY_TIMINGS_PATH`` + - If set, per-slab query latencies are appended as JSONL to this path. + * - ``TRX_RSS_SAMPLES_PATH`` + - If set, time-series RSS samples during file-size runs are appended as JSONL. + +Example — slower storage with a reduced streamline cap: + +.. code-block:: bash + + TRX_BENCH_BUFFER_MULTIPLIER=4 \ + TRX_BENCH_MAX_STREAMLINES=5000000 \ + ./bench/run_benchmarks.sh \ + --reference /path/to/reference.trx \ + --out-dir bench/results_hdd + +RSS measurement +~~~~~~~~~~~~~~~ + +The ``max_rss_kb`` counter in every benchmark reports the **per-iteration RSS delta**: +the maximum increase in resident memory observed during a single benchmark iteration, +sampled at the start and end of each iteration. This avoids contamination from earlier +iterations whose pages may still be resident in the OS page cache. + +On macOS the current RSS is read via ``mach_task_basic_info.resident_size``; on Linux +via ``VmRSS`` in ``/proc/self/status``. The ``run_benchmarks.sh`` script runs each +benchmark group (filesize, translate, query) as a separate process to prevent +cross-group RSS accumulation. 
diff --git a/docs/building.rst b/docs/building.rst index 202af97..03ec2d0 100644 --- a/docs/building.rst +++ b/docs/building.rst @@ -11,16 +11,15 @@ Required: - Eigen3 -Installing dependencies: +Installing dependencies ------------------------ -These examples include installing google test, -but this is only necessary if you want to build the tests. -Similarly, ninja is not strictly necessary, but it is recommended. -zlib is only required if you want to use the NIfTI I/O features. +The examples below include GoogleTest, which is only required when building +the tests. Ninja is optional but recommended. zlib is only required for the +NIfTI I/O features. On Debian-based systems the zip tools have been split into separate packages -on recent ubuntu versions. +on recent Ubuntu versions. .. code-block:: bash @@ -70,7 +69,7 @@ Key CMake options: - ``TRX_BUILD_TESTS``: Build tests (default OFF) - ``TRX_BUILD_DOCS``: Build docs with Doxygen/Sphinx (default OFF) - ``TRX_ENABLE_CLANG_TIDY``: Run clang-tidy during builds (default OFF) -- ``TRX_USE_CONAN```: Use Conan setup in ```cmake/ConanSetup.cmake`` (default OFF) +- ``TRX_USE_CONAN``: Use Conan setup in ``cmake/ConanSetup.cmake`` (default OFF) To use trx-cpp from another CMake project after installation: diff --git a/docs/concepts.rst b/docs/concepts.rst new file mode 100644 index 0000000..fe36314 --- /dev/null +++ b/docs/concepts.rst @@ -0,0 +1,113 @@ +Core Concepts +============= + +This page explains how TRX-cpp represents tractography data on disk and in +memory. 
+
+The TRX format
+--------------
+
+A TRX file is a ZIP archive (or on-disk directory) whose layout directly
+encodes its data model:
+
+- ``header.json`` — spatial metadata
+- ``positions.<dtype>`` — all streamline vertices in a single flat array
+- ``offsets.<dtype>`` — prefix-sum index from streamlines into positions
+- ``dpv/<name>.<dtype>`` — per-vertex metadata arrays
+- ``dps/<name>.<dtype>`` — per-streamline metadata arrays
+- ``groups/<name>.uint32`` — named index sets of streamlines
+- ``dpg/<group>/<name>.<dtype>`` — per-group metadata arrays
+
+Coordinates are stored in **RAS+ world space** (millimeters), matching the
+convention used by MRtrix3 ``.tck`` and NIfTI qform outputs.
+
+Positions array
+---------------
+
+All streamline vertices are stored in a single flat matrix of shape
+``(NB_VERTICES, 3)``. Keeping all vertices contiguous enables efficient
+memory mapping and avoids per-streamline allocations.
+
+In trx-cpp, ``positions`` is backed by a ``mio::shared_mmap_sink`` and
+exposed as an ``Eigen::Matrix`` view, giving zero-copy read access after
+the file is mapped.
+
+Offsets and the sentinel
+------------------------
+
+``offsets`` is a prefix-sum index of length ``NB_STREAMLINES + 1``. Element
+*i* is the offset in ``positions`` of the first vertex of streamline *i*.
+The final element is a **sentinel** equal to ``NB_VERTICES``, which makes
+length computation trivial without special-casing the last streamline:
+
+.. code-block:: cpp
+
+   size_t length_i = offsets[i + 1] - offsets[i];
+
+This design avoids per-streamline allocations and makes slicing the global
+positions array fast and uniform.
+
+Data per vertex (DPV)
+---------------------
+
+A DPV array stores one value per vertex in ``positions``. It has shape
+``(NB_VERTICES, 1)`` for scalar fields or ``(NB_VERTICES, N)`` for
+vector-valued fields. Typical uses:
+
+- FA values along the tract
+- Per-point RGB colors
+- Confidence or weight measures per vertex
+
+DPV arrays live under ``dpv/`` and are memory-mapped in the same way as
+``positions``.
+
+Data per streamline (DPS)
+-------------------------
+
+A DPS array stores one value per streamline. It has shape
+``(NB_STREAMLINES, 1)`` or ``(NB_STREAMLINES, N)``. Typical uses:
+
+- Mean FA or average curvature per tract
+- Per-streamline cluster labels
+- Tractography algorithm weights
+
+DPS arrays live under ``dps/`` and are mapped into ``MMappedMatrix`` objects.
+
+Groups
+------
+
+A group is a named list of streamline indices stored as a ``uint32`` array
+under ``groups/``. Groups enable sparse, overlapping labeling: a streamline
+can belong to multiple groups, and groups can have different sizes. Typical
+uses:
+
+- Bundle labels (``CST_L``, ``CC``, ``SLF_R``, ...)
+- Cluster assignments from QuickBundles or similar algorithms
+- Connectivity subsets (streamlines connecting two ROIs)
+
+Data per group (DPG)
+--------------------
+
+DPG attaches metadata to a group. Each group folder ``dpg/<group>/`` can
+contain any number of scalar or vector arrays. Typical uses:
+
+- Mean FA across the bundle
+- Per-bundle display color
+- Volume or surface-area estimates
+
+In trx-cpp, groups are ``MMappedMatrix`` objects and DPG fields are
+``MMappedMatrix
`` entries, both memory-mapped. + +Header +------ + +``header.json`` stores: + +- ``VOXEL_TO_RASMM`` — 4×4 affine mapping voxel indices to RAS+ world + coordinates (mm) +- ``DIMENSIONS`` — reference image grid dimensions as three ``uint16`` values +- ``NB_STREAMLINES`` — number of streamlines (``uint32``) +- ``NB_VERTICES`` — total number of vertices across all streamlines (``uint64``) + +The header is primarily for human readability and downstream compatibility. +The authoritative sizes come from the array dimensions themselves. diff --git a/docs/downstream_usage.rst b/docs/downstream_usage.rst deleted file mode 100644 index 6e1dc9b..0000000 --- a/docs/downstream_usage.rst +++ /dev/null @@ -1,449 +0,0 @@ -Downstream Usage -================ - -Writing MRtrix-style streamlines to TRX ---------------------------------------- - -MRtrix3 streamlines are created by ``MR::DWI::Tractography::Tracking::Exec``, -which appends points to a per-streamline container as tracking progresses. -That container is ``MR::DWI::Tractography::Tracking::GeneratedTrack``, which is -a ``std::vector`` with some tracking metadata -fields (seed index and status). During tracking, each call that advances the -streamline pushes the current position into this vector, so the resulting -streamline is just an ordered list of 3D points. - -In practice: - -- Streamline points are stored as ``Eigen::Vector3f`` entries in - ``GeneratedTrack``. -- The tracking code appends ``method.pos`` into that vector as each step - completes (seed point first, then subsequent vertices). -- The final output is a list of accepted ``GeneratedTrack`` instances, each - representing one streamline. - -TRX stores streamlines as a single ``positions`` array of shape ``(NB_VERTICES, 3)`` -and an ``offsets`` array of length ``(NB_STREAMLINES + 1)`` that provides the -prefix-sum offsets of each streamline in ``positions``. 
The example below shows -one idea on how to convert a list of MRtrix ``GeneratedTrack`` streamlines into a TRX file. - -.. code-block:: cpp - - #include - #include "dwi/tractography/tracking/generated_track.h" - - using MR::DWI::Tractography::Tracking::GeneratedTrack; - - void write_trx_from_mrtrix(const std::vector &tracks, - const std::string &out_path) { - // Count accepted streamlines and total vertices. - std::vector accepted; - accepted.reserve(tracks.size()); - size_t total_vertices = 0; - for (const auto &tck : tracks) { - if (tck.get_status() != GeneratedTrack::status_t::ACCEPTED) { - continue; - } - accepted.push_back(&tck); - total_vertices += tck.size(); - } - - const size_t nb_streamlines = accepted.size(); - const size_t nb_vertices = total_vertices; - - // Allocate a TRX file (float positions) with the desired sizes. - trx::TrxFile trx(nb_vertices, nb_streamlines); - - auto &positions = trx.streamlines->_data; // (NB_VERTICES, 3) - auto &offsets = trx.streamlines->_offsets; // (NB_STREAMLINES + 1, 1) - auto &lengths = trx.streamlines->_lengths; // (NB_STREAMLINES, 1) - - size_t cursor = 0; - offsets(0) = 0; - for (size_t i = 0; i < nb_streamlines; ++i) { - const auto &tck = *accepted[i]; - lengths(i) = static_cast(tck.size()); - offsets(i + 1) = offsets(i) + tck.size(); - - for (size_t j = 0; j < tck.size(); ++j, ++cursor) { - positions(cursor, 0) = tck[j].x(); - positions(cursor, 1) = tck[j].y(); - positions(cursor, 2) = tck[j].z(); - } - } - - trx.save(out_path, ZIP_CM_STORE); - trx.close(); - } - - -Streaming TRX from MRtrix tckgen --------------------------------- - -MRtrix ``tckgen`` writes streamlines as they are generated. To stream into TRX -without buffering all streamlines in memory, use ``trx::TrxStream`` to append -streamlines and finalize once tracking completes. This mirrors how the tck -writer works. - -.. 
code-block:: cpp - - #include - #include "dwi/tractography/tracking/generated_track.h" - - using MR::DWI::Tractography::Tracking::GeneratedTrack; - - trx::TrxStream trx_stream; - - // Called for each accepted streamline. - void on_streamline(const GeneratedTrack &tck) { - std::vector xyz; - xyz.reserve(tck.size() * 3); - for (const auto &pt : tck) { - xyz.push_back(pt[0]); - xyz.push_back(pt[1]); - xyz.push_back(pt[2]); - } - trx_stream.push_streamline(xyz); - } - - trx_stream.finalize("tracks.trx", ZIP_CM_STORE); - - -Using TRX in DSI Studio ------------------------ - -DSI Studio stores tractography in ``tract_model.cpp`` as a list of per-tract -point arrays and optional cluster assignments. The TRX format maps cleanly onto -this representation: - -- DSI Studio cluster assignments map to TRX ``groups/`` files. Each cluster is a - group containing the indices of streamlines that belong to that cluster. -- Per-streamline values (e.g., DSI's loaded scalar values) map to TRX DPS - (``data_per_streamline``) arrays. -- Per-vertex values (e.g., along-tract scalars) map to TRX DPV - (``data_per_vertex``) arrays. - -This means a TRX file can carry the tract geometry, cluster membership, and -both per-streamline and per-vertex metrics in a single archive, and DSI Studio -can round-trip these fields without custom sidecars. - -Usage sketch (DSI Studio) -~~~~~~~~~~~~~~~~~~~~~~~~~ - -.. code-block:: cpp - - // DSI Studio tract data is stored as std::vector> - // (see tract_model.cpp), with each streamline as an interleaved xyz list. - // Coordinates are in DSI Studio's internal voxel space; convert to RASMM - // as needed (e.g., multiply by voxel size and apply transforms). - std::vector> streamlines = /* DSI Studio tract_data */; - - // Optional per-streamline cluster labels or group membership. - std::vector cluster_ids = /* same length as streamlines */; - - // Convert to TRX positions/offsets. 
- size_t total_vertices = 0; - for (const auto &sl : streamlines) { - total_vertices += sl.size() / 3; - } - - trx::TrxFile trx(static_cast(total_vertices), - static_cast(streamlines.size())); - auto &positions = trx.streamlines->_data; - auto &offsets = trx.streamlines->_offsets; - auto &lengths = trx.streamlines->_lengths; - - size_t cursor = 0; - offsets(0) = 0; - for (size_t i = 0; i < streamlines.size(); ++i) { - const auto &sl = streamlines[i]; - const size_t points = sl.size() / 3; - lengths(i) = static_cast(points); - offsets(i + 1) = offsets(i) + points; - for (size_t p = 0; p < points; ++p, ++cursor) { - positions(cursor, 0) = sl[p * 3 + 0]; - positions(cursor, 1) = sl[p * 3 + 1]; - positions(cursor, 2) = sl[p * 3 + 2]; - } - } - - // Map cluster labels to TRX groups (one group per label). - std::map> clusters; - for (size_t i = 0; i < cluster_ids.size(); ++i) { - clusters[cluster_ids[i]].push_back(static_cast(i)); - } - for (const auto &kv : clusters) { - trx.add_group_from_indices("cluster_" + std::to_string(kv.first), kv.second); - } - - trx.save("out.trx", ZIP_CM_STORE); - trx.close(); - -Using TRX with nibrary (dmriTrekker) ------------------------------------- - -nibrary (used by dmriTrekker for tractogram handling) could provide TRX -reading and writing support in its tractography I/O layer. TRX fits this -pipeline well because it exposes the same primitives that nibrary uses -internally: a list of streamlines (each a list of 3D points) plus optional -per-streamline and per-vertex fields. - -Coordinate systems: - -- TRX ``positions`` are stored in world space (RASMM), which is RAS+ and matches - the coordinate system used by MRtrix3's ``.tck`` format. nibrary's internal - streamline points are the same coordinates written out by its TCK writer, so - those points map directly to TRX ``positions`` when using the same reference - space. 
-- TRX header fields ``VOXEL_TO_RASMM`` and ``DIMENSIONS`` should be populated from - the reference image used by dmriTrekker/nibrary so downstream tools interpret - coordinates consistently. - -Usage sketch (nibrary) -~~~~~~~~~~~~~~~~~~~~~~ - -.. code-block:: cpp - - // nibrary uses Streamline = std::vector and Tractogram = std::vector. - using NIBR::Streamline; - using NIBR::Tractogram; - - Tractogram nibr_streamlines = /* nibrary tractogram data */; - - // Write nibrary streamlines to TRX. - size_t total_vertices = 0; - for (const auto &sl : nibr_streamlines) { - total_vertices += sl.size(); - } - - trx::TrxFile trx_out(static_cast(total_vertices), - static_cast(nibr_streamlines.size())); - auto &positions = trx_out.streamlines->_data; - auto &offsets = trx_out.streamlines->_offsets; - auto &lengths = trx_out.streamlines->_lengths; - - size_t cursor = 0; - offsets(0) = 0; - for (size_t i = 0; i < nibr_streamlines.size(); ++i) { - const auto &sl = nibr_streamlines[i]; - lengths(i) = static_cast(sl.size()); - offsets(i + 1) = offsets(i) + sl.size(); - for (size_t p = 0; p < sl.size(); ++p, ++cursor) { - positions(cursor, 0) = sl[p][0]; - positions(cursor, 1) = sl[p][1]; - positions(cursor, 2) = sl[p][2]; - } - } - - trx_out.save("tracks.trx", ZIP_CM_STORE); - trx_out.close(); - - // Read TRX into nibrary-style streamlines. 
- auto trx_in = trx::load_any("tracks.trx"); - const auto pos = trx_in.positions.as_matrix(); - const auto offs = trx_in.offsets.as_matrix(); - - Tractogram out_streamlines; - out_streamlines.reserve(trx_in.num_streamlines()); - for (size_t i = 0; i < trx_in.num_streamlines(); ++i) { - const size_t start = static_cast(offs(i, 0)); - const size_t end = static_cast(offs(i + 1, 0)); - Streamline sl; - sl.reserve(end - start); - for (size_t j = start; j < end; ++j) { - sl.push_back({pos(j, 0), pos(j, 1), pos(j, 2)}); - } - out_streamlines.push_back(std::move(sl)); - } - - trx_in.close(); - -Loading MITK Diffusion streamlines into TRX -------------------------------------------- - -MITK Diffusion stores its streamline output in ``StreamlineTrackingFilter`` as a -``BundleType``, which is a ``std::vector`` of ``FiberType`` objects. Each ``FiberType`` -is a ``std::deque>``, i.e., an ordered list of 3D points in -physical space. Converting this to TRX follows the same pattern as any other -list-of-points representation: flatten all points into the TRX ``positions`` -array and build a prefix-sum ``offsets`` array. - -Note on physical space, headers, and affines: - -- MITK streamlines are in physical space (millimeters) using ITK's LPS+ - convention by default. TRX ``positions`` are expected to be in RASMM, so you - should flip the x and y axes when writing TRX (and flip back when reading). - -The sketch below shows how to write a ``BundleType`` to TRX and how to reconstruct -it from a TRX file if needed. - -.. 
code-block:: cpp - - #include - #include - #include - #include - - using FiberType = std::deque>; - using BundleType = std::vector; - - void mitk_bundle_to_trx(const BundleType &bundle, const std::string &out_path) { - size_t total_vertices = 0; - for (const auto &fiber : bundle) { - total_vertices += fiber.size(); - } - - trx::TrxFile trx(total_vertices, bundle.size()); - auto &positions = trx.streamlines->_data; - auto &offsets = trx.streamlines->_offsets; - auto &lengths = trx.streamlines->_lengths; - - size_t cursor = 0; - offsets(0) = 0; - for (size_t i = 0; i < bundle.size(); ++i) { - const auto &fiber = bundle[i]; - lengths(i) = static_cast(fiber.size()); - offsets(i + 1) = offsets(i) + fiber.size(); - size_t j = 0; - for (const auto &pt : fiber) { - // LPS (MITK/ITK) -> RAS (TRX) - positions(cursor, 0) = -pt[0]; - positions(cursor, 1) = -pt[1]; - positions(cursor, 2) = pt[2]; - ++cursor; - ++j; - } - } - - trx.save(out_path, ZIP_CM_STORE); - trx.close(); - } - - BundleType trx_to_mitk_bundle(const std::string &trx_path) { - auto trx = trx::load_any(trx_path); - const auto positions = trx.positions.as_matrix(); // (NB_VERTICES, 3) - const auto offsets = trx.offsets.as_matrix(); // (NB_STREAMLINES + 1, 1) - - BundleType bundle; - bundle.reserve(trx.num_streamlines()); - - for (size_t i = 0; i < trx.num_streamlines(); ++i) { - const size_t start = static_cast(offsets(i, 0)); - const size_t end = static_cast(offsets(i + 1, 0)); - FiberType fiber; - fiber.resize(end - start); - for (size_t j = start; j < end; ++j) { - // RAS (TRX) -> LPS (MITK/ITK) - fiber[j - start][0] = -positions(j, 0); - fiber[j - start][1] = -positions(j, 1); - fiber[j - start][2] = positions(j, 2); - } - bundle.push_back(std::move(fiber)); - } - - trx.close(); - return bundle; - } - -TRX in ITK-SNAP slice views ----------------------------- - -ITK-SNAP is a very nice image viewer that has not had the ability to visualize -streamlines. 
It is a very useful tool to check image alignment, especially if -you are working with ITK/ANTs, as it interprets image headers using ITK. - -Streamlines could be added to ITK-SNAP slice views by adding a renderer delegate to -the slice rendering pipeline. DSI Studio and MRview both have the ability to plot -streamlines on slices, but neither use ITK to interpret nifti headers. Someday -when TRX is directly integrated into ITK, ITK-SNAP integration could be used -to check that ``antsApplyTransformsToTRX`` is working correctly. - -Where to maybe integrate: - -- ``GUI/Qt/Components/SliceViewPanel`` sets up the slice view and installs - renderer delegates. -- ``GUI/Renderer/GenericSliceRenderer`` and ``SliceRendererDelegate`` define - the overlay rendering API (lines, polylines, etc.). -- Existing overlays (e.g., ``CrosshairsRenderer`` and - ``PolygonDrawingRenderer``) show how to draw line-based primitives. - -Possible workflow: - -1. Load a TRX file via GUI. -2. Create a new renderer delegate (e.g., ``StreamlineTrajectoryRenderer``) that: - - Filters streamlines that intersect the current slice plane (optionally - using cached AABBs for speed). - - Projects 3D points into slice coordinates using - ``GenericSliceModel::MapImageToSlice`` or - ``GenericSliceModel::MapSliceToImagePhysical``. - - Draws each trajectory with ``DrawPolyLine`` in the render context. -3. Register the delegate in ``SliceViewPanel`` so it renders above the image. - -Coordinate systems: - -- ITK-SNAP uses LPS+ physical coordinates by default. -- TRX stores positions in RAS+ world coordinates, so x/y should be negated when - moving between TRX and ITK-SNAP physical space. - -This design keeps the TRX integration localized to the slice overlay system and -does not require changes to core ITK-SNAP data structures. - -TRX in SlicerDMRI ------------------ - -SlicerDMRI represents tractography as ``vtkPolyData`` inside a -``vtkMRMLFiberBundleNode``. 
TRX support is implemented by converting TRX -structures to that polydata representation (and back on save). - -High-level mapping: - -- TRX ``positions`` + ``offsets`` map to polydata points and polyline cells. - Each streamline becomes one line cell; point coordinates are stored in RAS+. -- TRX DPV (data-per-vertex) becomes ``PointData`` arrays on the polydata. -- TRX DPS (data-per-streamline) becomes ``CellData`` arrays on the polydata. -- TRX groups are represented as a single label array per streamline - (``TRX_GroupId``), with a name table stored in ``FieldData`` as - ``TRX_GroupNames``. - -Round-trip metadata convention: - -- DPV arrays are stored as ``TRX_DPV_`` in ``PointData``. -- DPS arrays are stored as ``TRX_DPS_`` in ``CellData``. -- The storage node only exports arrays with these prefixes back into TRX, so - metadata remains recognizable and round-trippable. - -How users can visualize and interact with TRX metadata in the Slicer GUI: - -- **DPV**: can be used for per-point coloring (e.g., FA along the fiber) by - selecting the corresponding ``TRX_DPV_*`` array as the scalar to display. -- **DPS**: can be used for per-streamline coloring or thresholding by selecting - a ``TRX_DPS_*`` array in the fiber bundle display controls. -- **Groups**: color by ``TRX_GroupId`` to show group membership, and use - thresholding or selection filters to isolate specific group ids. The group - id-to-name mapping is stored in ``TRX_GroupNames`` for reference. - -Users can add their own DPV/DPS arrays in Slicer (via Python, modules, or -filters). To ensure these arrays are written back into TRX, name them with the -``TRX_DPV_`` or ``TRX_DPS_`` prefixes and keep them single-component with the -correct tuple counts (points for DPV, streamlines for DPS). - -TrxReader vs TrxFile --------------------- - -``TrxFile
`` is the core typed container. It owns the memory-mapped arrays, -exposes streamlines and metadata as Eigen matrices, and provides mutation and -save operations. The template parameter ``DT`` fixes the positions dtype -(``half``, ``float``, or ``double``), which allows zero-copy access and avoids -per-element conversions. - -``TrxReader
`` is a small RAII wrapper that loads a TRX file and manages the -backing resources. It ensures the temporary extraction directory (for zipped -TRX) is cleaned up when the reader goes out of scope, and provides safe access -to the underlying ``TrxFile``. This separation keeps ``TrxFile`` focused on the -data model, while ``TrxReader`` handles ownership and lifecycle concerns for -loaded files. - -In practice, most downstream users do not need to instantiate ``TrxReader`` -directly. The common entry points are convenience functions like -``trx::load_any`` or ``trx::with_trx_reader`` and higher-level wrappers that -return a ready-to-use ``TrxFile``. ``TrxReader`` remains available for advanced -use cases where explicit lifetime control of the backing resources is needed. \ No newline at end of file diff --git a/docs/index.rst b/docs/index.rst index f10b93a..c8bfb7b 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -1,22 +1,59 @@ TRX-cpp Documentation ===================== +.. image:: https://readthedocs.org/projects/trx-cpp/badge/?version=latest + :target: https://trx-cpp.readthedocs.io/en/latest/ + :alt: Documentation Status + .. image:: https://codecov.io/gh/tee-ar-ex/trx-cpp/branch/main/graph/badge.svg :target: https://codecov.io/gh/tee-ar-ex/trx-cpp :alt: codecov -TRX-cpp is a C++ library for reading, writing, and memory-mapping the TRX -tractography file format. +A C++17 library for reading, writing, and memory-mapping the TRX tractography +format. .. toctree:: :maxdepth: 2 :caption: Getting Started - overview + introduction + quick_start building - usage + +.. toctree:: + :maxdepth: 2 + :caption: User Guide + + concepts + api_layers + reading + writing + streaming + spatial_queries + nifti + +.. toctree:: + :maxdepth: 2 + :caption: Integration Guide + + integration + +.. toctree:: + :maxdepth: 2 + :caption: Performance + benchmarks - downstream_usage + +.. toctree:: + :maxdepth: 2 + :caption: TRX Format + + spec + +.. 
toctree:: + :maxdepth: 2 + :caption: Contributing + linting .. toctree:: diff --git a/docs/integration.rst b/docs/integration.rst new file mode 100644 index 0000000..e462c95 --- /dev/null +++ b/docs/integration.rst @@ -0,0 +1,316 @@ +Integration Guide +================= + +This page provides worked examples for integrating trx-cpp into common +tractography frameworks. Each example shows how to map the framework's +internal streamline representation to TRX and back. + +All examples assume that coordinates are already in RAS+ world space +(millimeters). If your framework uses a different coordinate convention, +apply the appropriate affine transform before writing to TRX. A common case +is LPS+ (used by ITK-based tools such as MITK), where you negate the x and y +components to convert to RAS+. + +MRtrix3 +------- + +MRtrix3 tracks are stored as ``GeneratedTrack`` objects +(``std::vector``) produced by the tracking engine. +Coordinates are in RAS+ world space and map directly to TRX ``positions``. + +**Bulk conversion** — when all streamlines are available in memory: + +.. 
code-block:: cpp

+   #include <trx/trx.h>
+   #include "dwi/tractography/tracking/generated_track.h"
+
+   using MR::DWI::Tractography::Tracking::GeneratedTrack;
+
+   void write_trx_from_mrtrix(const std::vector<GeneratedTrack>& tracks,
+                              const std::string& out_path) {
+     std::vector<const GeneratedTrack*> accepted;
+     size_t total_vertices = 0;
+     for (const auto& tck : tracks) {
+       if (tck.get_status() != GeneratedTrack::status_t::ACCEPTED) continue;
+       accepted.push_back(&tck);
+       total_vertices += tck.size();
+     }
+
+     trx::TrxFile<float> trx(total_vertices, accepted.size());
+
+     auto& positions = trx.streamlines->_data;
+     auto& offsets = trx.streamlines->_offsets;
+     auto& lengths = trx.streamlines->_lengths;
+
+     size_t cursor = 0;
+     offsets(0) = 0;
+     for (size_t i = 0; i < accepted.size(); ++i) {
+       const auto& tck = *accepted[i];
+       lengths(i) = static_cast<uint32_t>(tck.size());
+       offsets(i + 1) = offsets(i) + tck.size();
+       for (const auto& pt : tck) {
+         positions(cursor, 0) = pt.x();
+         positions(cursor, 1) = pt.y();
+         positions(cursor, 2) = pt.z();
+         ++cursor;
+       }
+     }
+
+     trx.save(out_path, ZIP_CM_STORE);
+     trx.close();
+   }
+
+**Streaming** — appending as each streamline is accepted, without buffering:
+
+.. code-block:: cpp
+
+   #include <trx/trx.h>
+   #include "dwi/tractography/tracking/generated_track.h"
+
+   using MR::DWI::Tractography::Tracking::GeneratedTrack;
+
+   trx::TrxStream trx_stream;
+
+   void on_streamline(const GeneratedTrack& tck) {
+     std::vector<float> xyz;
+     xyz.reserve(tck.size() * 3);
+     for (const auto& pt : tck) {
+       xyz.push_back(pt[0]);
+       xyz.push_back(pt[1]);
+       xyz.push_back(pt[2]);
+     }
+     trx_stream.push_streamline(xyz);
+   }
+
+   // Call once after all streamlines are generated:
+   trx_stream.finalize("tracks.trx", ZIP_CM_STORE);
+
+DSI Studio
+----------
+
+DSI Studio stores tractography in ``tract_model.cpp`` as
+``std::vector<std::vector<float>>`` with interleaved XYZ values. Cluster
+assignments, per-streamline scalars, and along-tract scalars map cleanly to
+TRX groups, DPS, and DPV respectively.
+
+..
code-block:: cpp + + std::vector> streamlines = /* DSI Studio tract_data */; + std::vector cluster_ids = /* one per streamline */; + + size_t total_vertices = 0; + for (const auto& sl : streamlines) total_vertices += sl.size() / 3; + + trx::TrxFile trx(total_vertices, streamlines.size()); + auto& positions = trx.streamlines->_data; + auto& offsets = trx.streamlines->_offsets; + auto& lengths = trx.streamlines->_lengths; + + size_t cursor = 0; + offsets(0) = 0; + for (size_t i = 0; i < streamlines.size(); ++i) { + const auto& sl = streamlines[i]; + const size_t pts = sl.size() / 3; + lengths(i) = static_cast(pts); + offsets(i + 1) = offsets(i) + pts; + for (size_t p = 0; p < pts; ++p, ++cursor) { + positions(cursor, 0) = sl[p * 3 + 0]; + positions(cursor, 1) = sl[p * 3 + 1]; + positions(cursor, 2) = sl[p * 3 + 2]; + } + } + + std::map> clusters; + for (size_t i = 0; i < cluster_ids.size(); ++i) { + clusters[cluster_ids[i]].push_back(static_cast(i)); + } + for (const auto& [label, indices] : clusters) { + trx.add_group_from_indices("cluster_" + std::to_string(label), indices); + } + + trx.save("out.trx", ZIP_CM_STORE); + trx.close(); + +nibrary / dmriTrekker +--------------------- + +nibrary uses ``Streamline = std::vector`` and +``Tractogram = std::vector``. Coordinates are in the same world +space as MRtrix3 ``.tck`` (RAS+) and map directly to TRX ``positions``. + +**Write nibrary streamlines to TRX:** + +.. 
code-block:: cpp + + using NIBR::Streamline; + using NIBR::Tractogram; + + Tractogram nibr = /* nibrary tractogram */; + + size_t total_vertices = 0; + for (const auto& sl : nibr) total_vertices += sl.size(); + + trx::TrxFile trx_out(total_vertices, nibr.size()); + auto& positions = trx_out.streamlines->_data; + auto& offsets = trx_out.streamlines->_offsets; + auto& lengths = trx_out.streamlines->_lengths; + + size_t cursor = 0; + offsets(0) = 0; + for (size_t i = 0; i < nibr.size(); ++i) { + const auto& sl = nibr[i]; + lengths(i) = static_cast(sl.size()); + offsets(i + 1) = offsets(i) + sl.size(); + for (size_t p = 0; p < sl.size(); ++p, ++cursor) { + positions(cursor, 0) = sl[p][0]; + positions(cursor, 1) = sl[p][1]; + positions(cursor, 2) = sl[p][2]; + } + } + + trx_out.save("tracks.trx", ZIP_CM_STORE); + trx_out.close(); + +**Read TRX into nibrary-style streamlines:** + +.. code-block:: cpp + + auto trx_in = trx::load_any("tracks.trx"); + const auto pos = trx_in.positions.as_matrix(); + const auto offs = trx_in.offsets.as_matrix(); + + Tractogram out; + out.reserve(trx_in.num_streamlines()); + for (size_t i = 0; i < trx_in.num_streamlines(); ++i) { + const size_t start = static_cast(offs(i, 0)); + const size_t end = static_cast(offs(i + 1, 0)); + Streamline sl; + sl.reserve(end - start); + for (size_t j = start; j < end; ++j) { + sl.push_back({pos(j, 0), pos(j, 1), pos(j, 2)}); + } + out.push_back(std::move(sl)); + } + + trx_in.close(); + +MITK Diffusion +-------------- + +MITK Diffusion stores streamlines as ``BundleType`` +(``std::vector``), where ``FiberType`` is +``std::deque>``. + +**Coordinate system note:** MITK/ITK uses LPS+ physical coordinates. TRX +expects RAS+. Negate the x and y components when writing to TRX, and negate +them again when reading back. + +.. 
code-block:: cpp + + #include + #include + #include + #include + + using FiberType = std::deque>; + using BundleType = std::vector; + + void mitk_bundle_to_trx(const BundleType& bundle, const std::string& out_path) { + size_t total_vertices = 0; + for (const auto& fiber : bundle) total_vertices += fiber.size(); + + trx::TrxFile trx(total_vertices, bundle.size()); + auto& positions = trx.streamlines->_data; + auto& offsets = trx.streamlines->_offsets; + auto& lengths = trx.streamlines->_lengths; + + size_t cursor = 0; + offsets(0) = 0; + for (size_t i = 0; i < bundle.size(); ++i) { + const auto& fiber = bundle[i]; + lengths(i) = static_cast(fiber.size()); + offsets(i + 1) = offsets(i) + fiber.size(); + for (const auto& pt : fiber) { + positions(cursor, 0) = -pt[0]; // LPS -> RAS: negate x + positions(cursor, 1) = -pt[1]; // LPS -> RAS: negate y + positions(cursor, 2) = pt[2]; + ++cursor; + } + } + + trx.save(out_path, ZIP_CM_STORE); + trx.close(); + } + + BundleType trx_to_mitk_bundle(const std::string& trx_path) { + auto trx = trx::load_any(trx_path); + const auto pos = trx.positions.as_matrix(); + const auto offs = trx.offsets.as_matrix(); + + BundleType bundle; + bundle.reserve(trx.num_streamlines()); + for (size_t i = 0; i < trx.num_streamlines(); ++i) { + const size_t start = static_cast(offs(i, 0)); + const size_t end = static_cast(offs(i + 1, 0)); + FiberType fiber(end - start); + for (size_t j = start; j < end; ++j) { + fiber[j - start][0] = -pos(j, 0); // RAS -> LPS + fiber[j - start][1] = -pos(j, 1); + fiber[j - start][2] = pos(j, 2); + } + bundle.push_back(std::move(fiber)); + } + + trx.close(); + return bundle; + } + +SlicerDMRI +---------- + +SlicerDMRI represents tractography as ``vtkPolyData`` inside a +``vtkMRMLFiberBundleNode``. TRX structures map to VTK data arrays as follows: + +- TRX ``positions`` + ``offsets`` → polydata points and polyline cells. + Each streamline becomes one line cell; point coordinates are in RAS+. 
+- TRX DPV → ``PointData`` arrays named ``TRX_DPV_``. +- TRX DPS → ``CellData`` arrays named ``TRX_DPS_``. +- TRX groups → a per-streamline ``TRX_GroupId`` label array in ``CellData``, + with a ``TRX_GroupNames`` name table in ``FieldData``. + +On save, the storage node exports only arrays with the ``TRX_DPV_`` or +``TRX_DPS_`` prefix back to TRX, ensuring clean round-trips without +extraneous fields. + +**Visualization in the Slicer GUI:** + +- DPV arrays appear in the fiber bundle display controls for per-point + coloring (e.g., FA along the fiber). +- DPS arrays support per-streamline coloring or thresholding. +- Groups can be visualized by coloring on ``TRX_GroupId`` and using + thresholding or selection filters to isolate specific group IDs. + +ITK-SNAP +-------- + +ITK-SNAP uses LPS+ physical coordinates. TRX positions are in RAS+, so +negate the x and y components in both directions when converting. + +Streamlines can be added to slice views by implementing a renderer delegate +in the slice rendering pipeline: + +- ``GUI/Qt/Components/SliceViewPanel`` installs renderer delegates. +- ``GUI/Renderer/GenericSliceRenderer`` and ``SliceRendererDelegate`` define + the overlay API (lines, polylines). +- ``CrosshairsRenderer`` and ``PolygonDrawingRenderer`` show how to draw + line-based primitives. + +A streamline renderer delegate would: + +1. Filter streamlines intersecting the current slice plane (using cached + AABBs from :func:`trx::TrxFile::build_streamline_aabbs` for speed). +2. Project 3D RAS+ points to slice coordinates via + ``GenericSliceModel::MapImageToSlice`` (after negating x/y to convert to + LPS+). +3. Draw each trajectory with ``DrawPolyLine`` in the render context. 
diff --git a/docs/introduction.rst b/docs/introduction.rst new file mode 100644 index 0000000..2ac528a --- /dev/null +++ b/docs/introduction.rst @@ -0,0 +1,50 @@ +Introduction +============ + +TRX-cpp is a C++17 library for reading, writing, and memory-mapping the +`TRX tractography format `_. TRX is +a ZIP-based container for fiber tract geometry and associated metadata, +designed for large-scale diffusion MRI tractography. + +Features +-------- + +**Zero-copy memory mapping** + Streamline positions, per-vertex data (DPV), and per-streamline data (DPS) + are exposed as ``Eigen::Map`` views directly over memory-mapped files. + Accessing a 10 M-streamline dataset does not require loading the full array + into RAM. + +**Streaming writes** + :class:`trx::TrxStream` appends streamlines incrementally and finalizes to + a TRX archive or directory once tracking is complete. Suitable for + tractography pipelines where the total count is unknown at the start. + +**Spatial queries** + Build per-streamline axis-aligned bounding boxes (AABBs) and extract spatial + subsets in sub-millisecond time per query. Designed for interactive + slice-view workflows that need to filter streamlines as the user moves + through a volume. + +**Typed and type-erased APIs** + :class:`trx::TrxFile` is templated on the positions dtype for compile-time + type safety. :class:`trx::AnyTrxFile` reads the dtype from disk and + dispatches at runtime — useful when the file format is not known in advance. + +**ZIP and directory storage** + Read and write ``.trx`` zip archives and plain on-disk directories with the + same API. Directory storage is convenient for in-place access; zip storage is + convenient for distribution and transfer. + +**Optional NIfTI support** + Read qform/sform affines from ``.nii``, ``.hdr``, or ``.nii.gz`` and embed + them in the TRX header for consistent coordinate interpretation downstream. 
+ +Where to go next +---------------- + +- :doc:`quick_start` — install the library and write a first program +- :doc:`building` — full dependency and build options reference +- :doc:`concepts` — how TRX-cpp represents streamlines and metadata internally +- :doc:`api_layers` — choosing between ``AnyTrxFile``, ``TrxFile
<T>``, and + ``TrxReader<T>
`` diff --git a/docs/nifti.rst b/docs/nifti.rst new file mode 100644 index 0000000..ccb284b --- /dev/null +++ b/docs/nifti.rst @@ -0,0 +1,44 @@ +NIfTI Header Support +==================== + +When trx-cpp is built with ``TRX_ENABLE_NIFTI=ON``, the optional NIfTI I/O +module can read qform/sform affines from ``.nii``, ``.hdr``, or ``.nii.gz`` +files and embed them in the TRX ``VOXEL_TO_RASMM`` header field. + +This is primarily useful when the TRX file must interoperate with the +``.trk`` (TrackVis) format, which stores coordinates in voxel space and +relies on the NIfTI header for the voxel-to-world transform. + +Attach a NIfTI affine to a TRX file +------------------------------------- + +.. code-block:: cpp + + #include + #include + + Eigen::Matrix4f affine = trx::read_nifti_voxel_to_rasmm("reference.nii.gz"); + + auto trx = trx::load("tracks.trx"); + trx->set_voxel_to_rasmm(affine); + trx->save("tracks_with_ref.trx"); + trx->close(); + +Notes +----- + +- The qform is preferred when present. If only the sform is available, it is + orthogonalized to a qform-equivalent matrix, consistent with ITK's NIfTI + handling. +- The qform/sform logic is adapted from nibabel's MIT-licensed implementation + (see ``third_party/nibabel/LICENSE``). +- zlib must be discoverable by CMake to build ``.nii.gz`` decompression + support. + +Enable at build time +--------------------- + +.. code-block:: bash + + cmake -S . -B build -DTRX_ENABLE_NIFTI=ON + cmake --build build diff --git a/docs/overview.rst b/docs/overview.rst deleted file mode 100644 index 37e783e..0000000 --- a/docs/overview.rst +++ /dev/null @@ -1,108 +0,0 @@ -Overview -======== - -TRX-cpp provides: - -- Read and write TRX archives and directories -- Memory-mapped access for large datasets -- Simple access to streamlines and metadata - -The API is header-focused under ``include/trx``, with implementation in ``src/``. 
- -TRX in C++ -========== - -TRX file format overview ------------------------- - -TRX is a tractography container for streamlines and auxiliary data. The core -geometry lives in two arrays: a single ``positions`` array containing all -points and an ``offsets`` array that delineates each streamline. Coordinates -are stored in **RAS+ world space (millimeters)**. This matches the coordinate -convention used by MRtrix ``.tck`` and many other tractography tools. -It also avoids the pitfalls of image-based coordinate systems. - -TRX can be stored either as: - -- A **directory** containing ``header.json`` and data files (positions, - offsets, dpv/dps/groups/dpg). -- A **zip archive** (``.trx``) with the same internal structure. - -Auxiliary data stored alongside streamlines: - -- **DPV** (data per vertex): values aligned with each point. -- **DPS** (data per streamline): values aligned with each streamline. -- **Groups**: named sets of streamline indices for labeling or clustering. -- **DPG** (data per group): values aligned with groups (one set per group). - -The ``header.json`` includes spatial metadata such as ``VOXEL_TO_RASMM`` and -``DIMENSIONS`` to preserve interpretation of coordinates. See the -`TRX specification `_ for details. - - -Positions array --------------- - -The ``positions`` array is a single contiguous matrix of shape -``(NB_VERTICES, 3)``. Storing all vertices in one array is cache-friendly and -enables efficient memory mapping. In trx-cpp, ``positions`` is backed by a -``mio::shared_mmap_sink`` and exposed as an ``Eigen::Matrix`` for zero-copy access -when possible. - -Offsets and the sentinel value ------------------------------- - -The ``offsets`` array is a prefix-sum index into ``positions``. Its length is -``NB_STREAMLINES + 1``. The final element is a **sentinel** that equals -``NB_VERTICES`` and makes length computation trivial: - -``length_i = offsets[i + 1] - offsets[i]``. 
- -This design avoids per-streamline allocations and supports fast slicing of the -global ``positions`` array. In trx-cpp, offsets are stored as ``uint64`` and -mapped directly into Eigen. - -Data per vertex (DPV) ---------------------- - -DPV stores a value for each vertex in ``positions``. Examples include FA values -along the tract, per-point colors, or confidence measures. DPV arrays have -shape ``(NB_VERTICES, 1)`` or ``(NB_VERTICES, N)`` for multi-component values. - -In trx-cpp, DPV fields are stored under ``dpv/`` and are memory-mapped similarly -to ``positions``. This keeps per-point metadata aligned and contiguous, which -is important for large tractograms. - -Data per streamline (DPS) -------------------------- - -DPS stores a value per streamline. Examples include streamline length, average -FA, or per-tract weights. DPS arrays have shape ``(NB_STREAMLINES, 1)`` or -``(NB_STREAMLINES, N)``. - -In trx-cpp, DPS fields live under ``dps/`` and are mapped into ``MMappedMatrix`` -objects, enabling efficient access without loading entire arrays into RAM. - -Groups and data per group (DPG) -------------------------------- - -Groups provide a sparse, overlapping labeling of streamlines. Each group is a -named list of streamline indices, and a streamline can belong to multiple -groups. Examples: - -- **Bundle labels** (e.g., ``CST_L``, ``CST_R``, ``CC``) -- **Clusters** from quickbundles or similar algorithms -- **Connectivity subsets** (e.g., streamlines that connect two ROIs) - -DPG (data per group) attaches metadata to each group. Examples: - -- Mean FA for each bundle -- A per-group color or display weight -- Scalar summaries computed over the group - -In the TRX on-disk layout, groups are stored under ``groups/`` as index arrays, -and DPG data is stored under ``dpg//`` as one or more arrays. - -In trx-cpp, groups are represented as ``MMappedMatrix`` objects and -DPG fields are stored as ``MMappedMatrix
`` entries. This keeps group data -as memory-mapped arrays so large group sets can be accessed without copying. diff --git a/docs/quick_start.rst b/docs/quick_start.rst new file mode 100644 index 0000000..6dd3eee --- /dev/null +++ b/docs/quick_start.rst @@ -0,0 +1,98 @@ +Quick Start +=========== + +This page walks through getting trx-cpp installed and writing a minimal +program that loads a TRX file, prints basic statistics, and saves a copy. + +Prerequisites +------------- + +- A C++17 compiler (GCC ≥ 7, Clang ≥ 5, MSVC 2019+) +- CMake ≥ 3.14 +- `libzip `_ +- `Eigen 3.4+ `_ + +See :doc:`building` for platform-specific installation commands. + +Install trx-cpp +--------------- + +Build and install the library so it can be found by ``find_package``: + +.. code-block:: bash + + git clone https://github.com/tee-ar-ex/trx-cpp.git + cmake -S trx-cpp -B trx-cpp/build \ + -DCMAKE_BUILD_TYPE=Release \ + -DTRX_BUILD_EXAMPLES=OFF \ + -DTRX_ENABLE_INSTALL=ON \ + -DCMAKE_INSTALL_PREFIX=$HOME/.local + cmake --build trx-cpp/build --config Release + cmake --install trx-cpp/build + +Alternatively, add trx-cpp as a subdirectory in your project (no install step +needed): + +.. code-block:: cmake + + add_subdirectory(path/to/trx-cpp) + target_link_libraries(my_app PRIVATE trx-cpp::trx) + +Write a first program +--------------------- + +Create a ``CMakeLists.txt`` and a ``main.cpp``: + +.. code-block:: cmake + + # CMakeLists.txt + cmake_minimum_required(VERSION 3.14) + project(hello_trx) + + find_package(trx-cpp CONFIG REQUIRED) + + add_executable(hello_trx main.cpp) + target_link_libraries(hello_trx PRIVATE trx-cpp::trx) + +.. 
code-block:: cpp + + // main.cpp + #include + #include + + int main(int argc, char* argv[]) { + if (argc < 2) { + std::cerr << "usage: hello_trx \n"; + return 1; + } + + auto trx = trx::load_any(argv[1]); + + std::cout << "streamlines : " << trx.num_streamlines() << "\n"; + std::cout << "vertices : " << trx.num_vertices() << "\n"; + std::cout << "dtype : " << trx.positions.dtype << "\n"; + + for (const auto& [name, arr] : trx.data_per_streamline) { + std::cout << "dps/" << name + << " (" << arr.rows() << " x " << arr.cols() << ")\n"; + } + + trx.close(); + return 0; + } + +Build and run: + +.. code-block:: bash + + cmake -S . -B build -DCMAKE_BUILD_TYPE=Release + cmake --build build + ./build/hello_trx /path/to/tracks.trx + +Next steps +---------- + +- :doc:`reading` — access streamline positions and metadata +- :doc:`writing` — create and save TRX files +- :doc:`streaming` — stream streamlines without buffering the full dataset +- :doc:`api_layers` — understand the three API layers diff --git a/docs/reading.rst b/docs/reading.rst new file mode 100644 index 0000000..8ef2345 --- /dev/null +++ b/docs/reading.rst @@ -0,0 +1,146 @@ +Reading TRX Files +================= + +This page covers the common patterns for loading and inspecting TRX data. +See :doc:`api_layers` for guidance on choosing between ``AnyTrxFile`` and +``TrxFile
``. + +Load and inspect +---------------- + +The simplest entry point is :func:`trx::load_any`, which detects the dtype +from the file and returns an :class:`trx::AnyTrxFile`: + +.. code-block:: cpp + + #include + + auto trx = trx::load_any("/path/to/tracks.trx"); + + std::cout << "dtype : " << trx.positions.dtype << "\n"; + std::cout << "streamlines : " << trx.num_streamlines() << "\n"; + std::cout << "vertices : " << trx.num_vertices() << "\n"; + + trx.close(); + +Access positions and offsets +----------------------------- + +Positions and offsets are exposed as :class:`trx::TypedArray` objects. Call +``as_matrix()`` to obtain an ``Eigen::Map`` view without copying data: + +.. code-block:: cpp + + auto pos = trx.positions.as_matrix(); // (NB_VERTICES, 3) + auto offs = trx.offsets.as_matrix(); // (NB_STREAMLINES + 1, 1) + + for (size_t i = 0; i < trx.num_streamlines(); ++i) { + const size_t start = static_cast(offs(i, 0)); + const size_t end = static_cast(offs(i + 1, 0)); + // vertices for streamline i: pos.block(start, 0, end - start, 3) + } + +Access DPV and DPS +------------------ + +Per-vertex (DPV) and per-streamline (DPS) metadata are stored in +``std::map`` containers: + +.. code-block:: cpp + + // List all DPS fields + for (const auto& [name, arr] : trx.data_per_streamline) { + std::cout << "dps/" << name + << " (" << arr.rows() << " x " << arr.cols() << ")\n"; + } + + // Access a specific DPV field + auto fa = trx.data_per_vertex.at("fa").as_matrix(); // (NB_VERTICES, 1) + +Access groups +------------- + +.. code-block:: cpp + + for (const auto& [name, indices] : trx.groups) { + std::cout << "group " << name << ": " + << indices.size() << " streamlines\n"; + } + +Typed access via TrxFile
+---------------------------- + +When the dtype is known ahead of time, use :func:`trx::load` for a typed +view. Positions and DPV arrays are exposed as ``Eigen::Matrix`` +directly, avoiding element-wise conversion: + +.. code-block:: cpp + + auto reader = trx::load("tracks.trx"); + auto& trx = *reader; + + // trx.streamlines->_data is Eigen::Matrix + // trx.streamlines->_offsets is Eigen::Matrix + + reader->close(); + +Iterating streamlines without copying +-------------------------------------- + +Because TRX positions are memory-mapped, the full positions array is never +read into RAM at once — the OS pages in only the regions you touch. You do +not need a separate "streaming reader" to process a 10 M-streamline file +without exhausting memory. + +To iterate over each streamline with zero per-streamline allocation, use +:func:`trx::TrxFile::for_each_streamline`. The callback receives the +streamline index, the start row in ``_data``, and the number of vertices: + +.. code-block:: cpp + + auto reader = trx::load("tracks.trx"); + auto& trx = *reader; + + trx.for_each_streamline([&](size_t idx, uint64_t start, uint64_t length) { + // Zero-copy block view of this streamline's vertices. + auto pts = trx.streamlines->_data.block( + static_cast(start), 0, + static_cast(length), 3); + + // pts is an Eigen expression — no heap allocation. + // Example: compute the centroid. + Eigen::Vector3f centroid = pts.colwise().mean(); + }); + + reader->close(); + +For random access to a single streamline, use +:func:`trx::TrxFile::get_streamline`. This copies the vertices into a +``std::vector`` and is convenient for one-off lookups, but avoid it in +tight loops over large tractograms: + +.. code-block:: cpp + + auto pts = trx.get_streamline(42); // std::vector> + +Chunk-based iteration (AnyTrxFile) +------------------------------------ + +:class:`trx::AnyTrxFile` provides ``for_each_positions_chunk``, which +iterates the positions buffer in fixed-size byte chunks. 
This is useful +for transcoding or checksum passes that process all vertices but do not +need the per-streamline boundary structure: + +.. code-block:: cpp + + auto trx = trx::load_any("tracks.trx"); + + trx.for_each_positions_chunk( + 4 * 1024 * 1024, // 4 MB chunks + [](trx::TrxScalarType dtype, const void* data, + size_t point_offset, size_t point_count) { + // data points to point_count * 3 values of the given dtype, + // starting at global vertex index point_offset. + }); + + trx.close();; diff --git a/docs/spatial_queries.rst b/docs/spatial_queries.rst new file mode 100644 index 0000000..c3e85e9 --- /dev/null +++ b/docs/spatial_queries.rst @@ -0,0 +1,61 @@ +Spatial Queries +=============== + +trx-cpp can build per-streamline axis-aligned bounding boxes (AABBs) and +use them to extract spatial subsets efficiently. This is useful for +interactive slice-view updates or region-of-interest filtering. + +Query by bounding box +--------------------- + +Pass minimum and maximum corners in RAS+ world coordinates (mm): + +.. code-block:: cpp + + #include + + auto trx = trx::load("/path/to/tracks.trx"); + + std::array min_corner{-10.0f, -10.0f, -10.0f}; + std::array max_corner{ 10.0f, 10.0f, 10.0f}; + + auto subset = trx->query_aabb(min_corner, max_corner); + subset->save("subset.trx", ZIP_CM_STORE); + subset->close(); + +Precompute the AABB cache +-------------------------- + +When issuing multiple spatial queries on the same file — for example, as a +user scrubs through slices in a viewer — precompute the AABB cache once and +pass it to each query: + +.. 
code-block:: cpp + + auto aabbs = trx->build_streamline_aabbs(); + + // Query 1 + auto s1 = trx->query_aabb(min1, max1, &aabbs); + + // Query 2 — reuses the same cached bounding boxes + auto s2 = trx->query_aabb(min2, max2, &aabbs); + + // Optionally build the AABB cache for the result as well + auto s3 = trx->query_aabb(min3, max3, &aabbs, /*build_aabbs_for_result=*/true); + +AABBs are stored in ``float16`` for memory efficiency. Comparisons are +performed in ``float32`` to avoid precision issues at the boundary. + +Subset by streamline IDs +------------------------- + +If you have a list of streamline indices from a prior step (clustering, +spatial query, manual selection), create a subset directly: + +.. code-block:: cpp + + std::vector ids{0, 4, 42, 99}; + auto subset = trx->subset_streamlines(ids); + + subset->save("subset_by_id.trx", ZIP_CM_STORE); + subset->close(); diff --git a/docs/spec.md b/docs/spec.md deleted file mode 100644 index 56c122f..0000000 --- a/docs/spec.md +++ /dev/null @@ -1,152 +0,0 @@ - -# Generals -- (Un)-Compressed Zip File or simple folder architecture - - File architecture describe the data - - Each file basename is the metadata’s name - - Each file extension is the metadata’s dtype - - Each file dimension is in the value between basename and metdata, 1-dimension array do not have to follow this convention for readability -- All arrays have a C-style memory layout(row-major) -- All arrays have a little-endian byte order -- Compression is optional - - use ZIP_STORE, if compression is desired use ZIP_DEFLATE - - Compressed TRX files will have to be decompressed before being loaded - -# Header -Only (or mostly) for use-readability, read-time checks and broader compatibility. 
- -- Dictionary in JSON - - VOXEL_TO_RASMM (4 lists of 4 float, 4x4 transformation matrix) - - DIMENSIONS (list of 3 uint16) - - NB_STREAMLINES (uint32) - - NB_VERTICES (uint64) - -# Arrays -# positions.float16 -- Written in world space(RASMM) - - Like TCK file -- Should always be a float16/32/64 - - Default could be float16 -- As contiguous 3D array(NB_VERTICES, 3) - -# offsets.uint64 -- Should always be a uint32/64 -- Where is the first vertex of each streamline, start at 0 -- Two ways of knowing how many vertices there are: - - Check the header - - Positions array size / dtypes / 3 - -- To get streamlines lengths: append the total number of vertices to the end of offsets and to the differences between consecutive elements of the array(ediff1d in numpy). - -# dpv (data_per_vertex) -- Always of size (NB_VERTICES, 1) or (NB_VERTICES, N) - -# dps (data_per_streamline) -- Always of size (NB_STREAMLINES, 1) or (NB_STREAMLINES, N) - -# Groups -Groups are tables of indices that allow sparse & overlapping representation(clusters, connectomics, bundles). 
-- All indices must be 0 < id < NB_STREAMLINES -- Datatype should be uint32 -- Allow to get a predefined streamlines subset from the memmaps efficiently -- Variables in sizes - -# dpg (data_per_group) -- Each folder is the name of a group -- Not all metadata have to be present in all groups -- Always of size (1,) or (N,) - -# Accepted extensions (datatype) -- int8/16/32/64 -- uint8/16/32/64 -- float16/32/64 - -# Example structure -```bash -OHBM_demo.trx -├── dpg -│ ├── AF_L -│ │ ├── mean_fa.float16 -│ │ ├── shuffle_colors.3.uint8 -│ │ └── volume.uint32 -│ ├── AF_R -│ │ ├── mean_fa.float16 -│ │ ├── shuffle_colors.3.uint8 -│ │ └── volume.uint32 -│ ├── CC -│ │ ├── mean_fa.float16 -│ │ ├── shuffle_colors.3.uint8 -│ │ └── volume.uint32 -│ ├── CST_L -│ │ └── shuffle_colors.3.uint8 -│ ├── CST_R -│ │ └── shuffle_colors.3.uint8 -│ ├── SLF_L -│ │ ├── mean_fa.float16 -│ │ ├── shuffle_colors.3.uint8 -│ │ └── volume.uint32 -│ └── SLF_R -│ ├── mean_fa.float16 -│ ├── shuffle_colors.3.uint8 -│ └── volume.uint32 -├── dpv -│ ├── color_x.uint8 -│ ├── color_y.uint8 -│ ├── color_z.uint8 -│ └── fa.float16 -├── dps -│ ├── algo.uint8 -│ ├── algo.json -│ ├── clusters_QB.uint16 -│ ├── commit_colors.3.uint8 -│ └── commit_weights.float32 -├── groups -│ ├── AF_L.uint32 -│ ├── AF_R.uint32 -│ ├── CC.uint32 -│ ├── CST_L.uint32 -│ ├── CST_R.uint32 -│ ├── SLF_L.uint32 -│ └── SLF_R.uint32 -├── header.json -├── offsets.uint64 -└── positions.3.float16 -``` - -# Example code -```python -from trx_file_memmap import TrxFile, load, save -import numpy as np - -trx = load('complete_big_v5.trx') - -# Access the header (dict) / streamlines (ArraySequences) -trx.header -trx.streamlines - -# Access the dpv (dict) / dps (dict) -trx.data_per_vertex -trx.data_per_streamline - -# Access the groups (dict) / dpg (dict) -trx.groups -trx.data_per_group - -# Get a random subset of 10000 streamlines -indices = np.arange(len(trx.streamlines._lengths)) -np.random.shuffle(indices) -sub_trx = trx.select(indices[0:10000]) 
-save(sub_trx, 'random_1000.trx') - -# Get sub-groups only, from the random subset -for key in sub_trx.groups.keys(): - group_trx = sub_trx.get_group(key) - save(group_trx, '{}.trx'.format(key)) - -# Pre-allocate memmaps and append 100x the random subset -alloc_trx = TrxFile(nb_streamlines=1500000, nb_vertices=500000000, init_as=trx) -for i in range(100): - alloc_trx.append(sub_trx) - -# Resize to remove the unused portion of the memmap -alloc_trx.resize() -``` diff --git a/docs/spec.rst b/docs/spec.rst new file mode 100644 index 0000000..2ea3ca6 --- /dev/null +++ b/docs/spec.rst @@ -0,0 +1,130 @@ +TRX Format Specification +======================== + +This page documents the on-disk layout and data model of the TRX +tractography format. TRX-cpp is an implementation of this specification; +the authoritative specification repository is at +https://github.com/tee-ar-ex/trx-spec. + +General layout +-------------- + +A TRX file is either an **uncompressed or compressed ZIP archive**, or a +**plain directory**. In both cases the internal structure is identical: the +file hierarchy describes the data, and filename components encode metadata. + +- Each file's **basename** is the name of the metadata field. +- Each file's **extension** is the dtype (e.g., ``float16``, ``uint32``). +- Multi-component arrays encode the number of components between the basename + and the extension (e.g., ``positions.3.float16`` has 3 components per row). + Single-component arrays may omit this field for readability. + +All arrays use **C-style (row-major) memory layout** and **little-endian +byte order**. + +Compression is optional. Use ``ZIP_STORE`` for uncompressed storage; use +``ZIP_DEFLATE`` if compression is desired. Compressed files must be +decompressed before memory-mapping. 
+ +Header +------ + +``header.json`` is a JSON dictionary with the following fields: + ++------------------+----------------------------+---------------------------+ +| Field | Type | Description | ++==================+============================+===========================+ +| VOXEL_TO_RASMM | 4×4 array of float | Affine from voxel to RAS+ | ++------------------+----------------------------+---------------------------+ +| DIMENSIONS | list of 3 uint16 | Reference image grid size | ++------------------+----------------------------+---------------------------+ +| NB_STREAMLINES | uint32 | Number of streamlines | ++------------------+----------------------------+---------------------------+ +| NB_VERTICES | uint64 | Total number of vertices | ++------------------+----------------------------+---------------------------+ + +The header is primarily for human readability and downstream compatibility +checks. The authoritative array sizes come from the data arrays themselves. + +Arrays +------ + +**positions** (``positions.float16`` / ``positions.float32`` / ``positions.float64``) + All streamline vertices as a contiguous C array of shape ``(NB_VERTICES, 3)``. + Stored in **RAS+ world space (millimeters)**, matching the MRtrix3 ``.tck`` + convention. + +**offsets** (``offsets.uint32`` / ``offsets.uint64``) + Prefix-sum index of length ``NB_STREAMLINES + 1``. Element *i* is the index + in ``positions`` of the first vertex of streamline *i*. The final element + is a sentinel equal to ``NB_VERTICES``. + + Streamline length: ``length_i = offsets[i+1] - offsets[i]``. + +**dpv — data per vertex** (``dpv/.``) + Shape ``(NB_VERTICES, 1)`` or ``(NB_VERTICES, N)``. Values are aligned + with ``positions`` row-by-row. + +**dps — data per streamline** (``dps/.``) + Shape ``(NB_STREAMLINES, 1)`` or ``(NB_STREAMLINES, N)``. Values are + aligned with streamlines. + +**groups** (``groups/.uint32``) + Variable-length index arrays. 
Each file lists the 0-based indices of + streamlines belonging to the named group. All indices must satisfy + ``0 ≤ id < NB_STREAMLINES``. Groups are non-exclusive: a streamline may + appear in multiple groups. + +**dpg — data per group** (``dpg//.``) + Shape ``(1,)`` or ``(N,)``. Each subdirectory corresponds to one group. + Not all metadata fields need to be present in every group. + +Accepted dtypes +--------------- + ++----------+----------+----------+ +| Signed | Unsigned | Float | ++==========+==========+==========+ +| int8 | uint8 | float16 | ++----------+----------+----------+ +| int16 | uint16 | float32 | ++----------+----------+----------+ +| int32 | uint32 | float64 | ++----------+----------+----------+ +| int64 | uint64 | | ++----------+----------+----------+ + +Example structure +----------------- + +.. code-block:: text + + OHBM_demo.trx + ├── dpg/ + │ ├── AF_L/ + │ │ ├── mean_fa.float16 + │ │ ├── shuffle_colors.3.uint8 + │ │ └── volume.uint32 + │ ├── AF_R/ CC/ CST_L/ CST_R/ SLF_L/ SLF_R/ + ├── dpv/ + │ ├── color_x.uint8 + │ ├── color_y.uint8 + │ ├── color_z.uint8 + │ └── fa.float16 + ├── dps/ + │ ├── algo.uint8 + │ ├── algo.json + │ ├── clusters_QB.uint16 + │ ├── commit_colors.3.uint8 + │ └── commit_weights.float32 + ├── groups/ + │ ├── AF_L.uint32 + │ ├── AF_R.uint32 + │ ├── CC.uint32 + │ ├── CST_L.uint32 + │ ├── CST_R.uint32 + │ ├── SLF_L.uint32 + │ └── SLF_R.uint32 + ├── header.json + ├── offsets.uint64 + └── positions.3.float16 diff --git a/docs/streaming.rst b/docs/streaming.rst new file mode 100644 index 0000000..b1e9736 --- /dev/null +++ b/docs/streaming.rst @@ -0,0 +1,166 @@ +Streaming Writes +================ + +:class:`trx::TrxStream` is an append-only writer for cases where the total +streamline count is not known ahead of time. It writes to temporary files and +finalizes to a standard TRX archive or directory when complete. + +.. note:: + + ``TrxStream`` is **not** thread-safe for concurrent writes. 
Use a single + writer thread (or the main thread) to append to the stream, while other + threads generate streamlines and deliver them via a queue. + +Single-threaded streaming +-------------------------- + +.. code-block:: cpp + + #include + + trx::TrxStream stream("float16"); + + for (/* each generated streamline */) { + std::vector> points = /* ... */; + stream.push_streamline(points); + } + + stream.finalize("tracks.trx", ZIP_CM_STORE); + +Multi-threaded producer / single writer +---------------------------------------- + +Worker threads generate streamlines and push batches into a queue. A +dedicated writer thread owns the ``TrxStream`` and consumes from the queue. + +.. code-block:: cpp + + #include + #include + #include + #include + #include + + struct Batch { + std::vector>> streamlines; + }; + + std::mutex mtx; + std::condition_variable cv; + std::queue q; + bool done = false; + + // Producer: generates streamlines, pushes batches into the queue. + auto producer = [&]() { + Batch batch; + batch.streamlines.reserve(1000); + for (int i = 0; i < 1000; ++i) { + batch.streamlines.push_back(/* generate points */); + } + { + std::lock_guard lock(mtx); + q.push(std::move(batch)); + } + cv.notify_one(); + }; + + // Writer: owns TrxStream, appends batches from the queue. 
+ trx::TrxStream stream("float16"); + auto writer = [&]() { + for (;;) { + std::unique_lock lock(mtx); + cv.wait(lock, [&] { return done || !q.empty(); }); + if (q.empty() && done) return; + Batch batch = std::move(q.front()); + q.pop(); + lock.unlock(); + for (const auto& pts : batch.streamlines) { + stream.push_streamline(pts); + } + } + }; + + std::thread writer_thread(writer); + std::thread t1(producer), t2(producer); + t1.join(); t2.join(); + { std::lock_guard lock(mtx); done = true; } + cv.notify_all(); + writer_thread.join(); + + stream.finalize("tracks.trx", ZIP_CM_STORE); + +MRtrix-style write kernel +-------------------------- + +MRtrix3 uses a multi-threaded producer stage and a single-writer kernel to +serialize output. The same pattern works with TRX by encapsulating the +``TrxStream`` inside a kernel object: + +.. code-block:: cpp + + struct TrxWriteKernel { + explicit TrxWriteKernel(const std::string& path) + : stream("float16"), out_path(path) {} + + void operator()(const std::vector>>& batch) { + for (const auto& pts : batch) { + stream.push_streamline(pts); + } + } + + void finalize() { + stream.finalize(out_path, ZIP_CM_STORE); + } + + private: + trx::TrxStream stream; + std::string out_path; + }; + +The key rule is: **only the writer thread touches ``TrxStream``**, while +worker threads only generate streamlines. + +Process-based sharding +----------------------- + +For very large tractograms generated in parallel processes, each process can +write to a shard directory and a parent process merges the shards afterward. + +``TrxStream`` provides two finalization methods for directory output: + +- ``finalize_directory()`` — removes any existing directory before writing. + Safe for single-process workflows where you control the full lifecycle. +- ``finalize_directory_persistent()`` — does **not** remove existing + directories. Required when a parent process pre-creates the output + directory. + +Recommended multiprocess pattern: + +1. 
**Parent** pre-creates shard directories. +2. Each **child** calls ``finalize_directory_persistent()`` after appending + all streamlines. +3. Child writes a sentinel file (e.g., ``SHARD_OK``) to signal completion. +4. **Parent** waits for all sentinels, then merges shards. + +.. code-block:: cpp + + // Parent: pre-create shard directories + for (size_t i = 0; i < num_shards; ++i) { + std::filesystem::create_directories("shards/shard_" + std::to_string(i)); + } + + // Child: write shard and signal completion + trx::TrxStream stream("float16"); + // ... push_streamline calls ... + stream.finalize_directory_persistent("/path/to/shards/shard_0"); + std::ofstream ok("/path/to/shards/shard_0/SHARD_OK"); + ok << "ok\n"; + + // Parent: merge shards after all SHARD_OK files are present. + // See bench/bench_trx_stream.cpp for a reference merge implementation. + +.. note:: + Use ``finalize_directory()`` for single-process writes where you want a + clean output state. Use ``finalize_directory_persistent()`` for + multiprocess workflows to avoid removing directories that may be checked + for existence by other processes. diff --git a/docs/usage.rst b/docs/usage.rst deleted file mode 100644 index 04b2619..0000000 --- a/docs/usage.rst +++ /dev/null @@ -1,290 +0,0 @@ -Usage -===== - -AnyTrxFile vs TrxFile ---------------------- - -``AnyTrxFile`` is the runtime-typed API. It reads the dtype from the file and -exposes arrays as ``TypedArray`` with a ``dtype`` string. This is the simplest -entry point when you only have a TRX path. - -``TrxFile
`` is the typed API. It is templated on the positions dtype -(``half``, ``float``, or ``double``) and maps data directly into Eigen matrices of -that type. It provides stronger compile-time guarantees but requires knowing -the dtype at compile time or doing manual dispatch. The recommended typed entry -point is :func:`trx::with_trx_reader`, which performs dtype detection and -dispatches to the matching ``TrxReader
``. - -See the API reference for details: :class:`trx::AnyTrxFile` and -:class:`trx::TrxFile`. - -Read a TRX zip and inspect data -------------------------------- - -.. code-block:: cpp - - #include - - using namespace trx; - - const std::string path = "/path/to/tracks.trx"; - - auto trx = load_any(path); - - std::cout << "dtype: " << trx.positions.dtype << "\n"; - std::cout << "Vertices: " << trx.num_vertices() << "\n"; - std::cout << "Streamlines: " << trx.num_streamlines() << "\n"; - - trx.close(); - - -Write a TRX file ----------------- - -.. code-block:: cpp - - auto trx = load_any("tracks.trx"); - auto header_obj = trx.header.object_items(); - header_obj["COMMENT"] = "saved by trx-cpp"; - trx.header = json(header_obj); - - trx.save("tracks_copy.trx", ZIP_CM_STORE); - trx.close(); - -Thread-safe streaming pattern ------------------------------ - -``TrxStream`` is **not** thread-safe for concurrent writes. A common pattern for -multi-core streamline generation is to use worker threads for generation and a -single writer thread (or the main thread) to append to ``TrxStream``. - -.. code-block:: cpp - - #include - #include - #include - #include - #include - - struct Batch { - std::vector>> streamlines; - }; - - std::mutex mutex; - std::condition_variable cv; - std::queue queue; - bool done = false; - - // Worker threads: generate streamlines and push batches. - auto producer = [&]() { - Batch batch; - batch.streamlines.reserve(1000); - for (int i = 0; i < 1000; ++i) { - std::vector> points = {/* ... generate ... */}; - batch.streamlines.push_back(std::move(points)); - } - { - std::lock_guard lock(mutex); - queue.push(std::move(batch)); - } - cv.notify_one(); - }; - - // Writer thread (single): pop batches and push into TrxStream. 
- trx::TrxStream stream("float16"); - auto consumer = [&]() { - for (;;) { - std::unique_lock lock(mutex); - cv.wait(lock, [&]() { return done || !queue.empty(); }); - if (queue.empty() && done) { - return; - } - Batch batch = std::move(queue.front()); - queue.pop(); - lock.unlock(); - - for (const auto &points : batch.streamlines) { - stream.push_streamline(points); - } - } - }; - - std::thread writer(consumer); - std::thread t1(producer); - std::thread t2(producer); - t1.join(); - t2.join(); - { - std::lock_guard lock(mutex); - done = true; - } - cv.notify_all(); - writer.join(); - - stream.finalize("tracks.trx", ZIP_CM_STORE); - -Process-based sharding and merge --------------------------------- - -For large tractograms it is common to generate streamlines in separate -processes, write shard outputs, and merge them later. ``TrxStream`` provides two -finalization methods for directory output: - -- ``finalize_directory()`` — Single-process variant that removes any existing - directory before writing. Use when you control the entire lifecycle. - -- ``finalize_directory_persistent()`` — Multiprocess-safe variant that does NOT - remove existing directories. Use when coordinating parallel writes where a - parent process may pre-create output directories. - -Recommended multiprocess pattern: - -1. **Parent** pre-creates shard directories to validate filesystem writability. -2. Each **child process** writes a directory shard using - ``finalize_directory_persistent()``. -3. After finalization completes, child writes a sentinel file (e.g., ``SHARD_OK``) - to signal completion. -4. **Parent** waits for all ``SHARD_OK`` markers before merging shards. - -This pattern avoids race conditions where the parent checks for directory -existence while children are still writing. - -.. 
code-block:: cpp - - // Parent process: pre-create shard directories - for (size_t i = 0; i < num_shards; ++i) { - const std::string shard_path = "shards/shard_" + std::to_string(i); - std::filesystem::create_directories(shard_path); - } - - // Fork child processes... - -.. code-block:: cpp - - // Child process: write to pre-created directory - trx::TrxStream stream("float16"); - // ... push streamlines, dpv, dps, groups ... - stream.finalize_directory_persistent("/path/to/shards/shard_0"); - - // Signal completion to parent - std::ofstream ok("/path/to/shards/shard_0/SHARD_OK"); - ok << "ok\n"; - ok.close(); - -.. code-block:: cpp - - // Parent process (after waiting for all SHARD_OK markers) - // Merge by concatenating positions/DPV/DPS, adjusting offsets/groups. - // See bench/bench_trx_stream.cpp for a reference merge implementation. - -.. note:: - Use ``finalize_directory()`` for single-process writes where you want to - ensure a clean output state. Use ``finalize_directory_persistent()`` for - multiprocess workflows to avoid removing directories that may be checked - for existence by other processes. - -MRtrix-style write kernel (single-writer) ------------------------------------------ - -MRtrix uses a multi-threaded producer stage and a single-writer kernel to -serialize streamlines to disk. The same pattern works for TRX by letting the -writer own the ``TrxStream`` and accepting batches from the thread queue. - -.. code-block:: cpp - - #include - #include - #include - - struct TrxWriteKernel { - explicit TrxWriteKernel(const std::string &path) - : stream("float16"), out_path(path) {} - - void operator()(const std::vector>> &batch) { - for (const auto &points : batch) { - stream.push_streamline(points); - } - } - - void finalize() { - stream.finalize(out_path, ZIP_CM_STORE); - } - - private: - trx::TrxStream stream; - std::string out_path; - }; - -This kernel can be used as the final stage of a producer pipeline. 
The key rule -is: **only the writer thread touches ``TrxStream``**, while worker threads only -generate streamlines. - -Optional NIfTI header support ------------------------------ - -If downstream software will have to interact with ``trk`` format data, the NIfTI header -is going to be essential to go back and forth between ``trk`` and ``trx``. - -When built with ``TRX_ENABLE_NIFTI=ON``, or by default if you're building the examples, -you can read a NIfTI header (``.nii``, ``.hdr``, optionally ``.nii.gz``) and populate -``VOXEL_TO_RASMM`` in a TRX header. The qform is preferred; if missing, the -sform is orthogonalized to a qform-equivalent matrix (consistent with ITK's handling of nifti). -If using this feature, you will also need zlib available. The qform/sform logic here is -translated from nibabel's MIT-licensed implementation (see ``third_party/nibabel/LICENSE``). - -.. code-block:: cpp - - #include - #include - - Eigen::Matrix4f affine = trx::read_nifti_voxel_to_rasmm("reference.nii.gz"); - auto trx = trx::load("tracks.trx"); - trx->set_voxel_to_rasmm(affine); - trx->save("tracks_with_ref.trx"); - trx->close(); - -Build and query AABBs ---------------------- - -TRX can build per-streamline axis-aligned bounding boxes (AABB) and use them to -extract a subset of streamlines intersecting a rectangular region. AABBs are -stored in ``float16`` for memory efficiency, while comparisons are done in -``float32``. - -.. code-block:: cpp - - #include - - auto trx = trx::load("/path/to/tracks.trx"); - - // Query an axis-aligned box (min/max corners in RAS+ world coordinates). 
- std::array min_corner{-10.0f, -10.0f, -10.0f}; - std::array max_corner{10.0f, 10.0f, 10.0f}; - - auto subset = trx->query_aabb(min_corner, max_corner); - // Or precompute and pass the AABB cache explicitly: - // auto aabbs = trx->build_streamline_aabbs(); - // auto subset = trx->query_aabb(min_corner, max_corner, &aabbs); - // Optionally build cache for the result: - // auto subset = trx->query_aabb(min_corner, max_corner, &aabbs, true); - subset->save("subset.trx", ZIP_CM_STORE); - subset->close(); - -Subset by streamline IDs ------------------------- - -If you already have a list of streamline indices (for example, from a clustering -step or a spatial query), you can create a new TrxFile directly from those -indices. - -.. code-block:: cpp - - #include - - auto trx = trx::load("/path/to/tracks.trx"); - - std::vector ids{0, 4, 42, 99}; - auto subset = trx->subset_streamlines(ids); - - subset->save("subset_by_id.trx", ZIP_CM_STORE); - subset->close(); diff --git a/docs/writing.rst b/docs/writing.rst new file mode 100644 index 0000000..07829fd --- /dev/null +++ b/docs/writing.rst @@ -0,0 +1,71 @@ +Writing TRX Files +================= + +This page covers creating TRX files from scratch and saving loaded files. +For append-only streaming writes when the total count is not known ahead of +time, see :doc:`streaming`. + +Create and save a TRX file +-------------------------- + +Allocate a :class:`trx::TrxFile` with the desired number of vertices and +streamlines, fill the positions and offsets arrays, then call ``save``: + +.. 
code-block:: cpp + + #include + + const size_t nb_vertices = 500; + const size_t nb_streamlines = 10; + + trx::TrxFile trx(nb_vertices, nb_streamlines); + + auto& positions = trx.streamlines->_data; // (NB_VERTICES, 3) + auto& offsets = trx.streamlines->_offsets; // (NB_STREAMLINES + 1, 1) + auto& lengths = trx.streamlines->_lengths; // (NB_STREAMLINES, 1) + + size_t cursor = 0; + offsets(0) = 0; + for (size_t i = 0; i < nb_streamlines; ++i) { + const size_t len = 50; // 50 vertices per streamline in this example + lengths(i) = static_cast(len); + offsets(i + 1) = offsets(i) + len; + for (size_t j = 0; j < len; ++j, ++cursor) { + positions(cursor, 0) = /* x */; + positions(cursor, 1) = /* y */; + positions(cursor, 2) = /* z */; + } + } + + trx.save("tracks.trx", ZIP_CM_STORE); + trx.close(); + +Pass ``ZIP_CM_DEFLATE`` instead of ``ZIP_CM_STORE`` to enable compression. +Compression reduces file size at the cost of slower read/write throughput; +``ZIP_CM_STORE`` is preferred for large files accessed over fast storage. + +Modify and re-save a loaded file +--------------------------------- + +.. code-block:: cpp + + auto trx = trx::load_any("tracks.trx"); + + // Add or update a header field. + auto header_obj = trx.header.object_items(); + header_obj["COMMENT"] = "processed by my_tool"; + trx.header = json11::Json(header_obj); + + trx.save("tracks_annotated.trx", ZIP_CM_STORE); + trx.close(); + +Save as a directory +------------------- + +Pass a directory path (without a ``.trx`` extension) to write an unzipped +TRX directory instead of a zip archive. Directory output avoids ZIP overhead +and is faster for large files on spinning disks: + +.. 
code-block:: cpp + + trx.save("/path/to/output_dir"); diff --git a/examples/load_trx.cpp b/examples/load_trx.cpp index 1b7fec2..ef72aa5 100644 --- a/examples/load_trx.cpp +++ b/examples/load_trx.cpp @@ -8,23 +8,23 @@ int main(int argc, char **argv) { // check_syntax off auto trx = trx::TrxFile::load_from_zip(argv[1]); - std::cout << "Vertices: " << trx->streamlines->_data.size() / 3 << "\n"; - std::cout << "First vertex (x,y,z): " << trx->streamlines->_data(0, 0) << "," << trx->streamlines->_data(0, 1) << "," - << trx->streamlines->_data(0, 2) << "\n"; - std::cout << "Streamlines: " << trx->streamlines->_offsets.size() << "\n"; - std::cout << "Vertices in first streamline: " << trx->streamlines->_offsets(1) - trx->streamlines->_offsets(0) + std::cout << "Vertices: " << trx->streamlines->data().size() / 3 << "\n"; + std::cout << "First vertex (x,y,z): " << trx->streamlines->data()(0, 0) << "," << trx->streamlines->data()(0, 1) + << "," << trx->streamlines->data()(0, 2) << "\n"; + std::cout << "Streamlines: " << trx->streamlines->offsets().size() << "\n"; + std::cout << "Vertices in first streamline: " << trx->streamlines->offsets()(1) - trx->streamlines->offsets()(0) << "\n"; std::cout << "dpg (data_per_group) items: " << trx->data_per_group.size() << "\n"; std::cout << "dps (data_per_streamline) items: " << trx->data_per_streamline.size() << "\n"; for (auto const &x : trx->data_per_streamline) { - std::cout << "'" << x.first << "' items: " << x.second->_matrix.size() << "\n"; + std::cout << "'" << x.first << "' items: " << x.second->matrix().size() << "\n"; } std::cout << "dpv (data_per_vertex) items:" << trx->data_per_vertex.size() << "\n"; for (auto const &x : trx->data_per_vertex) { - std::cout << "'" << x.first << "' items: " << x.second->_data.size() << "\n"; + std::cout << "'" << x.first << "' items: " << x.second->data().size() << "\n"; } std::cout << *trx << std::endl; -} \ No newline at end of file +} diff --git a/include/trx/detail/dtype_helpers.h 
b/include/trx/detail/dtype_helpers.h index 703edb2..e1a0d96 100644 --- a/include/trx/detail/dtype_helpers.h +++ b/include/trx/detail/dtype_helpers.h @@ -3,12 +3,41 @@ #include +#include #include #include namespace trx { namespace detail { +// Central helper that performs the ONE placement-new + reinterpret_cast needed +// to (re)bind an Eigen::Map to a new memory region. All other code should call +// this instead of scattering placement-new across the codebase. +// +// MapType must be an Eigen::Map> type. +template +inline void remap(MapType &map, void *data, int rows, int cols) { + using Scalar = typename MapType::Scalar; + new (&map) MapType(reinterpret_cast(data), rows, cols); // NOLINT +} + +// Overload for const data pointers (read-only maps). +template +inline void remap(MapType &map, const void *data, int rows, int cols) { + using Scalar = typename MapType::Scalar; + new (&map) MapType(const_cast(reinterpret_cast(data)), rows, cols); // NOLINT +} + +// Convenience overloads that unpack a (rows, cols) shape tuple. +template +inline void remap(MapType &map, void *data, const std::tuple &shape) { + remap(map, data, std::get<0>(shape), std::get<1>(shape)); +} +template +inline void remap(MapType &map, const void *data, const std::tuple &shape) { + remap(map, data, std::get<0>(shape), std::get<1>(shape)); +} + int _sizeof_dtype(const std::string &dtype); std::string _get_dtype(const std::string &dtype); bool _is_dtype_valid(const std::string &ext); diff --git a/include/trx/detail/exceptions.h b/include/trx/detail/exceptions.h new file mode 100644 index 0000000..54ee822 --- /dev/null +++ b/include/trx/detail/exceptions.h @@ -0,0 +1,41 @@ +#ifndef TRX_DETAIL_EXCEPTIONS_H +#define TRX_DETAIL_EXCEPTIONS_H + +#include +#include + +namespace trx { + +/// Base exception for all TRX library errors. 
+class TrxError : public std::runtime_error { +public: + using std::runtime_error::runtime_error; +}; + +/// I/O errors: zip failures, file not found, mmap errors, write failures. +class TrxIOError : public TrxError { +public: + using TrxError::TrxError; +}; + +/// Format errors: wrong sizes, missing fields, corrupt data, structural issues. +class TrxFormatError : public TrxError { +public: + using TrxError::TrxError; +}; + +/// Dtype errors: unsupported or mismatched data types. +class TrxDTypeError : public TrxError { +public: + using TrxError::TrxError; +}; + +/// Argument errors: invalid API arguments. +class TrxArgumentError : public TrxError { +public: + using TrxError::TrxError; +}; + +} // namespace trx + +#endif // TRX_DETAIL_EXCEPTIONS_H diff --git a/include/trx/detail/zip_raii.h b/include/trx/detail/zip_raii.h new file mode 100644 index 0000000..ef2eb2c --- /dev/null +++ b/include/trx/detail/zip_raii.h @@ -0,0 +1,102 @@ +#ifndef TRX_DETAIL_ZIP_RAII_H +#define TRX_DETAIL_ZIP_RAII_H + +#include +#include + +#include + +namespace trx { +namespace detail { + +/// RAII wrapper for zip_t*. Calls zip_discard() on destruction unless commit() is called. +class ZipArchive { +public: + ZipArchive() = default; + explicit ZipArchive(zip_t *z) : z_(z) {} + + ~ZipArchive() { + if (z_) { + if (committed_) { + zip_close(z_); + } else { + zip_discard(z_); + } + } + } + + ZipArchive(const ZipArchive &) = delete; + ZipArchive &operator=(const ZipArchive &) = delete; + ZipArchive(ZipArchive &&o) noexcept : z_(o.z_), committed_(o.committed_) { o.z_ = nullptr; } + ZipArchive &operator=(ZipArchive &&o) noexcept { + if (this != &o) { + if (z_) + zip_discard(z_); + z_ = o.z_; + committed_ = o.committed_; + o.z_ = nullptr; + } + return *this; + } + + /// Commit changes (calls zip_close on destruction instead of zip_discard). 
+ void commit(const std::string &path = "") { + committed_ = true; + if (z_ && zip_close(z_) != 0) { + auto err = zip_strerror(z_); + z_ = nullptr; // prevent double-close + throw TrxIOError("Unable to close archive " + path + ": " + err); + } + z_ = nullptr; + } + + zip_t *get() const { return z_; } + explicit operator bool() const { return z_ != nullptr; } + + /// Release ownership without closing. + zip_t *release() { + auto *p = z_; + z_ = nullptr; + return p; + } + +private: + zip_t *z_ = nullptr; + bool committed_ = false; +}; + +/// RAII wrapper for zip_file_t*. Calls zip_fclose() on destruction. +class ZipFile { +public: + ZipFile() = default; + explicit ZipFile(zip_file_t *f) : f_(f) {} + + ~ZipFile() { + if (f_) + zip_fclose(f_); + } + + ZipFile(const ZipFile &) = delete; + ZipFile &operator=(const ZipFile &) = delete; + ZipFile(ZipFile &&o) noexcept : f_(o.f_) { o.f_ = nullptr; } + ZipFile &operator=(ZipFile &&o) noexcept { + if (this != &o) { + if (f_) + zip_fclose(f_); + f_ = o.f_; + o.f_ = nullptr; + } + return *this; + } + + zip_file_t *get() const { return f_; } + explicit operator bool() const { return f_ != nullptr; } + +private: + zip_file_t *f_ = nullptr; +}; + +} // namespace detail +} // namespace trx + +#endif // TRX_DETAIL_ZIP_RAII_H diff --git a/include/trx/trx.h b/include/trx/trx.h index 34c3bde..782c2e6 100644 --- a/include/trx/trx.h +++ b/include/trx/trx.h @@ -22,6 +22,7 @@ #include #include #include +#include #include #include #include @@ -32,6 +33,9 @@ #include #include +#include +#include + namespace trx { namespace fs = std::filesystem; } @@ -205,6 +209,14 @@ inline const std::array dtypes = {"float16", "float64"}; template struct ArraySequence { + // Public accessors + auto &data() { return _data; } + const auto &data() const { return _data; } + auto &offsets() { return _offsets; } + const auto &offsets() const { return _offsets; } + auto &lengths() { return _lengths; } + const auto &lengths() const { return _lengths; } + 
Eigen::Map> _data; Eigen::Map> _offsets; Eigen::Matrix _lengths; @@ -216,6 +228,10 @@ template struct ArraySequence { }; template struct MMappedMatrix { + // Public accessor + auto &matrix() { return _matrix; } + const auto &matrix() const { return _matrix; } + Eigen::Map> _matrix; mio::shared_mmap_sink mmap; @@ -363,6 +379,10 @@ template class TrxFile { return 0; } + /// Returns an empty TrxFile that inherits this file's header metadata but has + /// NB_VERTICES and NB_STREAMLINES reset to zero. + std::unique_ptr> make_empty_like() const; + /** * @brief Build per-streamline axis-aligned bounding boxes (AABB). * @@ -378,19 +398,19 @@ template class TrxFile { * The box is defined by min/max corners in TRX coordinates. * Returns a new TrxFile with positions, DPV/DPS, and groups remapped. * Optionally builds the AABB cache for the returned TrxFile. - */ - /** - * @brief Extract a subset of streamlines intersecting an axis-aligned box. * - * The box is defined by min/max corners in TRX coordinates. - * Returns a new TrxFile with positions, DPV/DPS, and groups remapped. - * Optionally builds the AABB cache for the returned TrxFile. + * If max_streamlines > 0 and more streamlines intersect the box than that + * limit, a random sample of exactly max_streamlines is returned instead of + * the full intersection. rng_seed controls the random draw so results are + * reproducible. The returned indices are sorted for efficient I/O. */ std::unique_ptr> query_aabb(const std::array &min_corner, const std::array &max_corner, const std::vector> *precomputed_aabbs = nullptr, - bool build_cache_for_result = false) const; + bool build_cache_for_result = false, + size_t max_streamlines = 0, + uint32_t rng_seed = 42) const; /** * @brief Extract a subset of streamlines by index. @@ -455,35 +475,18 @@ template class TrxFile { */ void remove_dpg_group(const std::string &group); -private: - mutable std::vector> aabb_cache_; - /** - * @brief Load a TrxFile from a zip archive. 
- * - * Internal: prefer TrxReader / with_trx_reader in public API. - */ + /// Load a TrxFile from a zip archive. static std::unique_ptr> load_from_zip(const std::string &path); - /** - * @brief Load a TrxFile from an on-disk directory. - * - * Internal: prefer TrxReader / with_trx_reader in public API. - */ + /// Load a TrxFile from an on-disk directory. static std::unique_ptr> load_from_directory(const std::string &path); - /** - * @brief Load a TrxFile from either a zip archive or directory. - * - * Internal: prefer TrxReader / with_trx_reader in public API. - */ + /// Load a TrxFile from either a zip archive or directory. static std::unique_ptr> load(const std::string &path); - /** - * @brief Get the real size of data (ignoring zeros of preallocation) - * - * @return std::tuple A tuple representing the index of the last streamline and the - * total length of all the streamlines - */ - std::tuple _get_real_len(); + + /// Access the backing directory path. + const std::string &uncompressed_folder_handle() const { return _uncompressed_folder_handle; } + std::string &uncompressed_folder_handle() { return _uncompressed_folder_handle; } /** * @brief Fill a TrxFile using another and start indexes (preallocation) @@ -498,6 +501,16 @@ template class TrxFile { std::tuple _copy_fixed_arrays_from(TrxFile
*trx, int strs_start = 0, int pts_start = 0, int nb_strs_to_copy = -1); int len(); + +private: + mutable std::vector> aabb_cache_; + /** + * @brief Get the real size of data (ignoring zeros of preallocation) + * + * @return std::tuple A tuple representing the index of the last streamline and the + * total length of all the streamlines + */ + std::tuple _get_real_len(); }; namespace detail { @@ -638,6 +651,14 @@ class AnyTrxFile { static AnyTrxFile load_from_zip(const std::string &path); static AnyTrxFile load_from_directory(const std::string &path); + /// Access the backing directory path. + const std::string &backing_directory() const { return _backing_directory; } + std::string &backing_directory() { return _backing_directory; } + + /// Access the uncompressed folder handle. + const std::string &uncompressed_folder_handle() const { return _uncompressed_folder_handle; } + std::string &uncompressed_folder_handle() { return _uncompressed_folder_handle; } + private: std::string _uncompressed_folder_handle; bool _owns_uncompressed_folder = false; @@ -703,24 +724,24 @@ class TrxStream { * InMemory keeps metadata in RAM until finalize (default). * OnDisk writes metadata to temp files and copies them at finalize. */ - void set_metadata_mode(MetadataMode mode); + TrxStream &set_metadata_mode(MetadataMode mode); /** * @brief Set max in-memory buffer size for metadata writes (bytes). * * Applies when MetadataMode::OnDisk. Larger buffers reduce write calls. */ - void set_metadata_buffer_max_bytes(std::size_t max_bytes); + TrxStream &set_metadata_buffer_max_bytes(std::size_t max_bytes); /** * @brief Set the VOXEL_TO_RASMM affine matrix in the header. */ - void set_voxel_to_rasmm(const Eigen::Matrix4f &affine); + TrxStream &set_voxel_to_rasmm(const Eigen::Matrix4f &affine); /** * @brief Set DIMENSIONS in the header. 
*/ - void set_dimensions(const std::array &dims); + TrxStream &set_dimensions(const std::array &dims); /** * @brief Add per-streamline values (DPS) from an in-memory vector. @@ -836,11 +857,11 @@ class TrxStream { }; /** - * TODO: This function might be completely unecessary + * @brief Copy header fields from a JSON root (currently unused; candidate for removal). * - * @param[in] root a Json::Value root obtained from reading a header file with JsonCPP - * @param[out] header a header containing the same elements as the original root - * */ + * @param[in] root a Json::Value root obtained from reading a header file + * @return header a header containing the same elements as the original root + */ json assignHeader(const json &root); /** @@ -1020,9 +1041,8 @@ void allocate_file(const std::string &path, std::size_t size); * @param offset offset of the data within the file * @return mio::shared_mmap_sink */ -// TODO: ADD order?? -// TODO: change tuple to vector to support ND arrays? -// TODO: remove data type as that's done outside of this function +// Known limitations: only row-major order supported; shape uses tuple (sufficient for 2D); +// dtype parameter is used only for byte-size computation. 
mio::shared_mmap_sink _create_memmap(std::string filename, const std::tuple &shape, const std::string &mode = "r", diff --git a/include/trx/trx.tpp b/include/trx/trx.tpp index fff2f05..f3c7167 100644 --- a/include/trx/trx.tpp +++ b/include/trx/trx.tpp @@ -12,6 +12,48 @@ using Eigen::Index; using Eigen::Map; using Eigen::Matrix; using Eigen::RowMajor; +inline void mkdir_or_throw(const std::string &path) { + std::error_code ec; + trx::fs::create_directories(path, ec); + if (ec) { + throw TrxIOError("Could not create directory " + path); + } +} + +inline json default_header() { + std::vector> affine(4, std::vector(4, 0.0f)); + for (int i = 0; i < 4; i++) { + affine[i][i] = 1.0f; + } + json::object obj; + obj["VOXEL_TO_RASMM"] = affine; + obj["DIMENSIONS"] = std::vector{1, 1, 1}; + obj["NB_VERTICES"] = 0; + obj["NB_STREAMLINES"] = 0; + return json(obj); +} + +inline std::string folder_from_path(const std::string &elem_filename, const std::string &root) { + trx::fs::path elem_path(elem_filename); + trx::fs::path folder_path = elem_path.parent_path(); + std::string folder; + if (!root.empty()) { + trx::fs::path rel_path = elem_path.lexically_relative(trx::fs::path(root)); + std::string rel_str = rel_path.string(); + if (!rel_str.empty() && rel_str.rfind("..", 0) != 0) { + folder = rel_path.parent_path().string(); + } else { + folder = folder_path.string(); + } + } else { + folder = folder_path.string(); + } + if (folder == ".") { + folder.clear(); + } + return folder; +} + template void write_binary(const std::string &filename, const Matrix &matrix) { std::ofstream out(filename, std::ios::out | std::ios::binary | std::ios::trunc); typename Matrix::Index rows = matrix.rows(), cols = matrix.cols(); @@ -73,6 +115,14 @@ std::string _generate_filename_from_data(const Eigen::MatrixBase
&arr, std:: return new_filename; } +template +std::unique_ptr> TrxFile
::make_empty_like() const { + auto empty = std::make_unique>(); + empty->header = _json_set(this->header, "NB_VERTICES", 0); + empty->header = _json_set(empty->header, "NB_STREAMLINES", 0); + return empty; +} + template TrxFile
::TrxFile(int nb_vertices, int nb_streamlines, const TrxFile
*init_as, std::string reference) { std::vector> affine(4); @@ -104,7 +154,7 @@ TrxFile
::TrxFile(int nb_vertices, int nb_streamlines, const TrxFile
*ini if (nb_vertices == 0 && nb_streamlines == 0) { if (init_as != nullptr) { // raise error here - throw std::invalid_argument("Can't us init_as without declaring nb_vertices and nb_streamlines"); + throw TrxArgumentError("Can't use init_as without declaring nb_vertices and nb_streamlines"); } // will remove as completely unecessary. using as placeholders @@ -132,7 +182,7 @@ TrxFile
::TrxFile(int nb_vertices, int nb_streamlines, const TrxFile
*ini trx->_owns_uncompressed_folder = false; trx->_uncompressed_folder_handle.clear(); } else { - throw std::invalid_argument("You must declare both NB_VERTICES AND NB_STREAMLINES"); + throw TrxArgumentError("You must declare both NB_VERTICES AND NB_STREAMLINES"); } json::object header_obj; @@ -181,17 +231,7 @@ std::unique_ptr> _initialize_empty_trx(int nb_streamlines, int nb_ve trx->streamlines = std::make_unique>(); trx->streamlines->mmap_pos = trx::_create_memmap(positions_filename, shape, "w+", positions_dtype); - // TODO: find a better way to get the dtype than using all these switch cases. - if (positions_dtype.compare("float16") == 0) { - new (&(trx->streamlines->_data)) Map>( - reinterpret_cast(trx->streamlines->mmap_pos.data()), std::get<0>(shape), std::get<1>(shape)); - } else if (positions_dtype.compare("float32") == 0) { - new (&(trx->streamlines->_data)) Map>( - reinterpret_cast(trx->streamlines->mmap_pos.data()), std::get<0>(shape), std::get<1>(shape)); - } else { - new (&(trx->streamlines->_data)) Map>( - reinterpret_cast(trx->streamlines->mmap_pos.data()), std::get<0>(shape), std::get<1>(shape)); - } + trx::detail::remap(trx->streamlines->_data, trx->streamlines->mmap_pos.data(), shape); std::string offsets_filename(tmp_dir); offsets_filename += "/offsets." 
+ offsets_dtype; @@ -199,8 +239,7 @@ std::unique_ptr> _initialize_empty_trx(int nb_streamlines, int nb_ve std::tuple shape_off = std::make_tuple(nb_streamlines + 1, 1); trx->streamlines->mmap_off = trx::_create_memmap(offsets_filename, shape_off, "w+", offsets_dtype); - new (&(trx->streamlines->_offsets)) Map>( - reinterpret_cast(trx->streamlines->mmap_off.data()), std::get<0>(shape_off), std::get<1>(shape_off)); + trx::detail::remap(trx->streamlines->_offsets, trx->streamlines->mmap_off.data(), shape_off); trx->streamlines->_lengths.resize(nb_streamlines); trx->streamlines->_lengths.setZero(); @@ -210,19 +249,11 @@ std::unique_ptr> _initialize_empty_trx(int nb_streamlines, int nb_ve std::string dps_dirname; if (init_as->data_per_vertex.size() > 0) { dpv_dirname = tmp_dir + "/dpv/"; - std::error_code ec; - trx::fs::create_directories(dpv_dirname, ec); - if (ec) { - throw std::runtime_error("Could not create directory " + dpv_dirname); - } + mkdir_or_throw(dpv_dirname); } if (init_as->data_per_streamline.size() > 0) { dps_dirname = tmp_dir + "/dps/"; - std::error_code ec; - trx::fs::create_directories(dps_dirname, ec); - if (ec) { - throw std::runtime_error("Could not create directory " + dps_dirname); - } + mkdir_or_throw(dps_dirname); } for (auto const &x : init_as->data_per_vertex) { @@ -245,21 +276,11 @@ std::unique_ptr> _initialize_empty_trx(int nb_streamlines, int nb_ve std::tuple dpv_shape = std::make_tuple(rows, cols); trx->data_per_vertex[x.first] = std::make_unique>(); trx->data_per_vertex[x.first]->mmap_pos = trx::_create_memmap(dpv_filename, dpv_shape, "w+", dpv_dtype); - if (dpv_dtype.compare("float16") == 0) { - new (&(trx->data_per_vertex[x.first]->_data)) Map>( - reinterpret_cast(trx->data_per_vertex[x.first]->mmap_pos.data()), rows, cols); - } else if (dpv_dtype.compare("float32") == 0) { - new (&(trx->data_per_vertex[x.first]->_data)) Map>( - reinterpret_cast(trx->data_per_vertex[x.first]->mmap_pos.data()), rows, cols); - } else { - new 
(&(trx->data_per_vertex[x.first]->_data)) Map>( - reinterpret_cast(trx->data_per_vertex[x.first]->mmap_pos.data()), rows, cols); - } + trx::detail::remap(trx->data_per_vertex[x.first]->_data, trx->data_per_vertex[x.first]->mmap_pos.data(), rows, + cols); - new (&(trx->data_per_vertex[x.first]->_offsets)) - Map>(trx->streamlines->_offsets.data(), - int(trx->streamlines->_offsets.rows()), - int(trx->streamlines->_offsets.cols())); + trx::detail::remap(trx->data_per_vertex[x.first]->_offsets, trx->streamlines->_offsets.data(), + int(trx->streamlines->_offsets.rows()), int(trx->streamlines->_offsets.cols())); trx->data_per_vertex[x.first]->_lengths = trx->streamlines->_lengths; } @@ -285,16 +306,8 @@ std::unique_ptr> _initialize_empty_trx(int nb_streamlines, int nb_ve trx->data_per_streamline[x.first]->mmap = trx::_create_memmap(dps_filename, dps_shape, std::string("w+"), dps_dtype); - if (dps_dtype.compare("float16") == 0) { - new (&(trx->data_per_streamline[x.first]->_matrix)) Map>( - reinterpret_cast(trx->data_per_streamline[x.first]->mmap.data()), rows, cols); - } else if (dps_dtype.compare("float32") == 0) { - new (&(trx->data_per_streamline[x.first]->_matrix)) Map>( - reinterpret_cast(trx->data_per_streamline[x.first]->mmap.data()), rows, cols); - } else { - new (&(trx->data_per_streamline[x.first]->_matrix)) Map>( - reinterpret_cast(trx->data_per_streamline[x.first]->mmap.data()), rows, cols); - } + trx::detail::remap(trx->data_per_streamline[x.first]->_matrix, trx->data_per_streamline[x.first]->mmap.data(), + rows, cols); } } @@ -317,7 +330,8 @@ TrxFile
::_create_trx_from_pointer(json header, std::string filename; - // TODO: Fix this hack of iterating through dictionary in reverse to get main files read first + // Iterate in reverse so that "positions" and "offsets" (which sort after "dpv"/"dps"/"groups") + // are processed first, before DPS/DPV entries that depend on them being initialized. for (auto x = dict_pointer_size.rbegin(); x != dict_pointer_size.rend(); ++x) { std::string elem_filename = x->first; @@ -327,38 +341,18 @@ TrxFile
::_create_trx_from_pointer(json header, filename = elem_filename; } - trx::fs::path elem_path(elem_filename); - trx::fs::path folder_path = elem_path.parent_path(); - std::string folder; - if (!root.empty()) { - trx::fs::path rel_path = elem_path.lexically_relative(trx::fs::path(root)); - std::string rel_str = rel_path.string(); - if (!rel_str.empty() && rel_str.rfind("..", 0) != 0) { - folder = rel_path.parent_path().string(); - } else { - folder = folder_path.string(); - } - } else { - folder = folder_path.string(); - } - if (folder == ".") { - folder.clear(); - } + std::string folder = folder_from_path(elem_filename, root); - // _split_ext_with_dimensionality - std::tuple base_tuple = trx::detail::_split_ext_with_dimensionality(elem_filename); - std::string base(std::get<0>(base_tuple)); - int dim = std::get<1>(base_tuple); - std::string ext(std::get<2>(base_tuple)); + auto [base, dim, ext] = trx::detail::_split_ext_with_dimensionality(elem_filename); long long mem_adress = std::get<0>(x->second); long long size = std::get<1>(x->second); - if (base.compare("positions") == 0 && (folder.compare("") == 0 || folder.compare(".") == 0)) { + if (base == "positions" && (folder.empty() || folder == ".")) { const auto nb_vertices = static_cast(trx->header["NB_VERTICES"].int_value()); const auto expected = nb_vertices * 3; if (size != expected || dim != 3) { - throw std::invalid_argument("Wrong data size/dimensionality: size=" + std::to_string(size) + + throw TrxFormatError("Wrong data size/dimensionality: size=" + std::to_string(size) + " expected=" + std::to_string(expected) + " dim=" + std::to_string(dim) + " filename=" + elem_filename); } @@ -367,25 +361,14 @@ TrxFile
::_create_trx_from_pointer(json header, trx->streamlines->mmap_pos = trx::_create_memmap(filename, shape, "r+", ext.substr(1, ext.size() - 1), mem_adress); - // TODO: find a better way to get the dtype than using all these switch cases. Also - // refactor into function as per specifications, positions can only be floats - if (ext.compare("float16") == 0) { - new (&(trx->streamlines->_data)) Map>( - reinterpret_cast(trx->streamlines->mmap_pos.data()), std::get<0>(shape), std::get<1>(shape)); - } else if (ext.compare("float32") == 0) { - new (&(trx->streamlines->_data)) Map>( - reinterpret_cast(trx->streamlines->mmap_pos.data()), std::get<0>(shape), std::get<1>(shape)); - } else { - new (&(trx->streamlines->_data)) Map>( - reinterpret_cast(trx->streamlines->mmap_pos.data()), std::get<0>(shape), std::get<1>(shape)); - } + trx::detail::remap(trx->streamlines->_data, trx->streamlines->mmap_pos.data(), shape); } - else if (base.compare("offsets") == 0 && (folder.compare("") == 0 || folder.compare(".") == 0)) { + else if (base == "offsets" && (folder.empty() || folder == ".")) { const auto nb_streamlines = static_cast(trx->header["NB_STREAMLINES"].int_value()); const auto expected = nb_streamlines + 1; if (size != expected || dim != 1) { - throw std::invalid_argument("Wrong offsets size/dimensionality: size=" + std::to_string(size) + + throw TrxFormatError("Wrong offsets size/dimensionality: size=" + std::to_string(size) + " expected=" + std::to_string(expected) + " dim=" + std::to_string(dim) + " filename=" + elem_filename); } @@ -394,18 +377,16 @@ TrxFile
::_create_trx_from_pointer(json header, std::tuple shape = std::make_tuple(nb_str + 1, 1); trx->streamlines->mmap_off = trx::_create_memmap(filename, shape, "r+", ext, mem_adress); - if (ext.compare("uint64") == 0) { - new (&(trx->streamlines->_offsets)) Map>( - reinterpret_cast(trx->streamlines->mmap_off.data()), std::get<0>(shape), std::get<1>(shape)); - } else if (ext.compare("uint32") == 0) { + if (ext == "uint64") { + trx::detail::remap(trx->streamlines->_offsets, trx->streamlines->mmap_off.data(), shape); + } else if (ext == "uint32") { trx->streamlines->_offsets_owned.resize(std::get<0>(shape)); - auto *src = reinterpret_cast(trx->streamlines->mmap_off.data()); + auto *src = reinterpret_cast(trx->streamlines->mmap_off.data()); // NOLINT for (int i = 0; i < std::get<0>(shape); ++i) trx->streamlines->_offsets_owned[static_cast(i)] = static_cast(src[i]); - new (&(trx->streamlines->_offsets)) Map>( - trx->streamlines->_offsets_owned.data(), std::get<0>(shape), std::get<1>(shape)); + trx::detail::remap(trx->streamlines->_offsets, trx->streamlines->_offsets_owned.data(), shape); } else { - throw std::invalid_argument("Unsupported offsets datatype: " + ext); + throw TrxDTypeError("Unsupported offsets datatype: " + ext); } Matrix offsets = trx->streamlines->_offsets; @@ -413,69 +394,34 @@ TrxFile
::_create_trx_from_pointer(json header, trx::detail::_compute_lengths(offsets, static_cast(trx->header["NB_VERTICES"].int_value())); } - else if (folder.compare("dps") == 0) { + else if (folder == "dps") { std::tuple shape; trx->data_per_streamline[base] = std::make_unique>(); int nb_scalar = size / static_cast(trx->header["NB_STREAMLINES"].int_value()); if (size % static_cast(trx->header["NB_STREAMLINES"].int_value()) != 0 || nb_scalar != dim) { - - throw std::invalid_argument("Wrong dps size/dimensionality"); + throw TrxFormatError("Wrong dps size/dimensionality"); } else { shape = std::make_tuple(static_cast(trx->header["NB_STREAMLINES"].int_value()), nb_scalar); } trx->data_per_streamline[base]->mmap = trx::_create_memmap(filename, shape, "r+", ext, mem_adress); - - if (ext.compare("float16") == 0) { - new (&(trx->data_per_streamline[base]->_matrix)) - Map>(reinterpret_cast(trx->data_per_streamline[base]->mmap.data()), - std::get<0>(shape), - std::get<1>(shape)); - } else if (ext.compare("float32") == 0) { - new (&(trx->data_per_streamline[base]->_matrix)) - Map>(reinterpret_cast(trx->data_per_streamline[base]->mmap.data()), - std::get<0>(shape), - std::get<1>(shape)); - } else { - new (&(trx->data_per_streamline[base]->_matrix)) Map>( - reinterpret_cast(trx->data_per_streamline[base]->mmap.data()), - std::get<0>(shape), - std::get<1>(shape)); - } + trx::detail::remap(trx->data_per_streamline[base]->_matrix, trx->data_per_streamline[base]->mmap.data(), shape); } - else if (folder.compare("dpv") == 0) { + else if (folder == "dpv") { std::tuple shape; trx->data_per_vertex[base] = std::make_unique>(); int nb_scalar = size / static_cast(trx->header["NB_VERTICES"].int_value()); if (size % static_cast(trx->header["NB_VERTICES"].int_value()) != 0 || nb_scalar != dim) { - - throw std::invalid_argument("Wrong dpv size/dimensionality"); + throw TrxFormatError("Wrong dpv size/dimensionality"); } else { shape = 
std::make_tuple(static_cast(trx->header["NB_VERTICES"].int_value()), nb_scalar); } trx->data_per_vertex[base]->mmap_pos = trx::_create_memmap(filename, shape, "r+", ext, mem_adress); - - if (ext.compare("float16") == 0) { - new (&(trx->data_per_vertex[base]->_data)) - Map>(reinterpret_cast(trx->data_per_vertex[base]->mmap_pos.data()), - std::get<0>(shape), - std::get<1>(shape)); - } else if (ext.compare("float32") == 0) { - new (&(trx->data_per_vertex[base]->_data)) - Map>(reinterpret_cast(trx->data_per_vertex[base]->mmap_pos.data()), - std::get<0>(shape), - std::get<1>(shape)); - } else { - new (&(trx->data_per_vertex[base]->_data)) Map>( - reinterpret_cast(trx->data_per_vertex[base]->mmap_pos.data()), - std::get<0>(shape), - std::get<1>(shape)); - } - - new (&(trx->data_per_vertex[base]->_offsets)) - Map>(trx->streamlines->_offsets.data(), std::get<0>(shape), std::get<1>(shape)); + trx::detail::remap(trx->data_per_vertex[base]->_data, trx->data_per_vertex[base]->mmap_pos.data(), shape); + trx::detail::remap(trx->data_per_vertex[base]->_offsets, trx->streamlines->_offsets.data(), + int(trx->streamlines->_offsets.rows()), int(trx->streamlines->_offsets.cols())); trx->data_per_vertex[base]->_lengths = trx->streamlines->_lengths; } @@ -484,7 +430,7 @@ TrxFile
::_create_trx_from_pointer(json header, if (size != dim) { - throw std::invalid_argument("Wrong dpg size/dimensionality"); + throw TrxFormatError("Wrong dpg size/dimensionality"); } else { shape = std::make_tuple(1, static_cast(size)); } @@ -495,188 +441,142 @@ TrxFile
::_create_trx_from_pointer(json header, trx->data_per_group[sub_folder][data_name] = std::make_unique>(); trx->data_per_group[sub_folder][data_name]->mmap = trx::_create_memmap(filename, shape, "r+", ext, mem_adress); - if (ext.compare("float16") == 0) { - new (&(trx->data_per_group[sub_folder][data_name]->_matrix)) Map>( - reinterpret_cast(trx->data_per_group[sub_folder][data_name]->mmap.data()), - std::get<0>(shape), - std::get<1>(shape)); - } else if (ext.compare("float32") == 0) { - new (&(trx->data_per_group[sub_folder][data_name]->_matrix)) Map>( - reinterpret_cast(trx->data_per_group[sub_folder][data_name]->mmap.data()), - std::get<0>(shape), - std::get<1>(shape)); - } else { - new (&(trx->data_per_group[sub_folder][data_name]->_matrix)) Map>( - reinterpret_cast(trx->data_per_group[sub_folder][data_name]->mmap.data()), - std::get<0>(shape), - std::get<1>(shape)); - } + trx::detail::remap(trx->data_per_group[sub_folder][data_name]->_matrix, + trx->data_per_group[sub_folder][data_name]->mmap.data(), shape); } - else if (folder.compare("groups") == 0) { + else if (folder == "groups") { std::tuple shape; if (dim != 1) { - throw std::invalid_argument("Wrong group dimensionality"); + throw TrxFormatError("Wrong group dimensionality"); } else { shape = std::make_tuple(static_cast(size), 1); } trx->groups[base] = std::make_unique>(); trx->groups[base]->mmap = trx::_create_memmap(filename, shape, "r+", ext, mem_adress); - new (&(trx->groups[base]->_matrix)) Map>( - reinterpret_cast(trx->groups[base]->mmap.data()), std::get<0>(shape), std::get<1>(shape)); + trx::detail::remap(trx->groups[base]->_matrix, trx->groups[base]->mmap.data(), shape); } else { - throw std::invalid_argument("Entry is not part of a valid TRX structure: " + elem_filename); + throw TrxFormatError("Entry is not part of a valid TRX structure: " + elem_filename); } } if (trx->streamlines->_data.size() == 0 || trx->streamlines->_offsets.size() == 0) { - throw std::invalid_argument("Missing essential 
data."); + throw TrxFormatError("Missing essential data."); } return trx; } -// TODO: Major refactoring template std::unique_ptr> TrxFile
::deepcopy() { if (!this->streamlines || this->streamlines->_data.size() == 0 || this->streamlines->_offsets.size() == 0) { auto empty_copy = std::make_unique>(); empty_copy->header = this->header; return empty_copy; } - std::string tmp_dir = make_temp_dir("trx"); - - std::string header = tmp_dir + SEPARATOR + "header.json"; - std::ofstream out_json(header); - // TODO: Definitely a better way to deepcopy + // Determine effective counts (handle sliced/non-copy-safe data) json tmp_header = this->header; - - auto to_dump = std::make_unique>(); - // TODO: Verify that this is indeed a deep copy - new (&(to_dump->_data)) Matrix(this->streamlines->_data); - new (&(to_dump->_offsets)) Matrix(this->streamlines->_offsets); - new (&(to_dump->_lengths)) Matrix(this->streamlines->_lengths); - + int nb_streamlines, nb_vertices; if (!this->_copy_safe) { - const int nb_streamlines = to_dump->_offsets.size() > 0 ? static_cast(to_dump->_offsets.size() - 1) : 0; - const int nb_vertices = static_cast(to_dump->_data.size() / 3); + nb_streamlines = static_cast(this->num_streamlines()); + nb_vertices = static_cast(this->streamlines->_data.size() / 3); tmp_header = _json_set(tmp_header, "NB_STREAMLINES", nb_streamlines); tmp_header = _json_set(tmp_header, "NB_VERTICES", nb_vertices); - } - // Ensure sentinel is correct before persisting - if (to_dump->_offsets.size() > 0) { - to_dump->_offsets(to_dump->_offsets.size() - 1) = static_cast(tmp_header["NB_VERTICES"].int_value()); - } - if (out_json.is_open()) { - out_json << tmp_header.dump() << std::endl; - out_json.close(); + } else { + nb_streamlines = tmp_header["NB_STREAMLINES"].int_value(); + nb_vertices = tmp_header["NB_VERTICES"].int_value(); } - std::string pos_rootfn = tmp_dir + SEPARATOR + "positions"; - std::string positions_filename = _generate_filename_from_data(to_dump->_data, pos_rootfn); + // Allocate a fresh TrxFile with memory-mapped storage + auto copy = _initialize_empty_trx
(nb_streamlines, nb_vertices, this); - write_binary(positions_filename, to_dump->_data); + // Copy header + copy->header = tmp_header; - std::string off_rootfn = tmp_dir + SEPARATOR + "offsets"; - std::string offsets_filename = _generate_filename_from_data(to_dump->_offsets, off_rootfn); + // Copy positions + copy->streamlines->_data = this->streamlines->_data; - write_binary(offsets_filename, to_dump->_offsets); + // Copy offsets + copy->streamlines->_offsets = this->streamlines->_offsets; + // Ensure sentinel is correct + if (copy->streamlines->_offsets.size() > 0) { + copy->streamlines->_offsets(copy->streamlines->_offsets.size() - 1) = static_cast(nb_vertices); + } - if (this->data_per_vertex.size() > 0) { - std::string dpv_dirname = tmp_dir + SEPARATOR + "dpv" + SEPARATOR; - { - std::error_code ec; - trx::fs::create_directories(dpv_dirname, ec); - if (ec) { - throw std::runtime_error("Could not create directory " + dpv_dirname); - } + // Copy lengths + copy->streamlines->_lengths = this->streamlines->_lengths; + + // Copy DPS + for (auto const &kv : this->data_per_streamline) { + auto it = copy->data_per_streamline.find(kv.first); + if (it != copy->data_per_streamline.end()) { + it->second->_matrix = kv.second->_matrix; } - for (auto const &x : this->data_per_vertex) { - Matrix dpv_todump = x.second->_data; - std::string dpv_filename = dpv_dirname + x.first; - dpv_filename = _generate_filename_from_data(dpv_todump, dpv_filename); + } - write_binary(dpv_filename, dpv_todump); + // Copy DPV + for (auto const &kv : this->data_per_vertex) { + auto it = copy->data_per_vertex.find(kv.first); + if (it != copy->data_per_vertex.end()) { + it->second->_data = kv.second->_data; + // _offsets is already correctly bound to copy->streamlines->_offsets by _initialize_empty_trx + it->second->_lengths = kv.second->_lengths; } } - if (this->data_per_streamline.size() > 0) { - std::string dps_dirname = tmp_dir + SEPARATOR + "dps" + SEPARATOR; + // Copy groups (not covered by 
_initialize_empty_trx) + std::string tmp_dir = copy->_uncompressed_folder_handle; + if (!this->groups.empty()) { + std::string groups_dirname = tmp_dir + SEPARATOR + "groups" + SEPARATOR; { std::error_code ec; - trx::fs::create_directories(dps_dirname, ec); - if (ec) { - throw std::runtime_error("Could not create directory " + dps_dirname); - } + trx::fs::create_directories(groups_dirname, ec); } - for (auto const &x : this->data_per_streamline) { - Matrix dps_todump = x.second->_matrix; - std::string dps_filename = dps_dirname + x.first; - dps_filename = _generate_filename_from_data(dps_todump, dps_filename); + for (auto const &kv : this->groups) { + std::string group_dtype = dtype_from_scalar(); + int rows = static_cast(kv.second->_matrix.rows()); + int cols = static_cast(kv.second->_matrix.cols()); + std::string group_filename = groups_dirname + kv.first; + group_filename = _generate_filename_from_data(kv.second->_matrix, group_filename); - write_binary(dps_filename, dps_todump); + std::tuple group_shape = std::make_tuple(rows, cols); + copy->groups[kv.first] = std::make_unique>(); + copy->groups[kv.first]->mmap = _create_memmap(group_filename, group_shape, "w+", group_dtype); + trx::detail::remap(copy->groups[kv.first]->_matrix, copy->groups[kv.first]->mmap.data(), rows, cols); + copy->groups[kv.first]->_matrix = kv.second->_matrix; } } - if (this->groups.size() > 0) { - std::string groups_dirname = tmp_dir + SEPARATOR + "groups" + SEPARATOR; + // Copy DPG (not covered by _initialize_empty_trx) + for (auto const &group_kv : this->data_per_group) { + std::string dpg_dirname = tmp_dir + SEPARATOR + "dpg" + SEPARATOR; + std::string dpg_subdirname = dpg_dirname + group_kv.first; { std::error_code ec; - trx::fs::create_directories(groups_dirname, ec); - if (ec) { - throw std::runtime_error("Could not create directory " + groups_dirname); - } + trx::fs::create_directories(dpg_subdirname, ec); } + for (auto const &field : group_kv.second) { + std::string dpg_dtype = 
dtype_from_scalar
(); + int rows = static_cast(field.second->_matrix.rows()); + int cols = static_cast(field.second->_matrix.cols()); + std::string dpg_filename = dpg_subdirname + SEPARATOR + field.first; + dpg_filename = _generate_filename_from_data(field.second->_matrix, dpg_filename); - for (auto const &x : this->groups) { - Matrix group_todump = x.second->_matrix; - std::string group_filename = groups_dirname + x.first; - group_filename = _generate_filename_from_data(group_todump, group_filename); - - write_binary(group_filename, group_todump); - - if (this->data_per_group.find(x.first) == this->data_per_group.end()) { - continue; - } - - for (auto const &y : this->data_per_group[x.first]) { - std::string dpg_dirname = tmp_dir + SEPARATOR + "dpg" + SEPARATOR; - std::string dpg_subdirname = dpg_dirname + x.first; - std::error_code ec; - if (!trx::fs::exists(dpg_dirname, ec)) { - ec.clear(); - trx::fs::create_directories(dpg_dirname, ec); - } - if (ec) { - throw std::runtime_error("Could not create directory " + dpg_dirname); - } - ec.clear(); - if (!trx::fs::exists(dpg_subdirname, ec)) { - ec.clear(); - trx::fs::create_directories(dpg_subdirname, ec); - } - if (ec) { - throw std::runtime_error("Could not create directory " + dpg_subdirname); - } - - Matrix dpg_todump = this->data_per_group[x.first][y.first]->_matrix; - std::string dpg_filename = dpg_subdirname + SEPARATOR + y.first; - dpg_filename = _generate_filename_from_data(dpg_todump, dpg_filename); - - write_binary(dpg_filename, dpg_todump); - } + std::tuple dpg_shape = std::make_tuple(rows, cols); + copy->data_per_group[group_kv.first][field.first] = std::make_unique>(); + copy->data_per_group[group_kv.first][field.first]->mmap = + _create_memmap(dpg_filename, dpg_shape, "w+", dpg_dtype); + trx::detail::remap(copy->data_per_group[group_kv.first][field.first]->_matrix, + copy->data_per_group[group_kv.first][field.first]->mmap.data(), rows, cols); + copy->data_per_group[group_kv.first][field.first]->_matrix = 
field.second->_matrix; } } - auto copy_trx = TrxFile
::load_from_directory(tmp_dir); - copy_trx->_uncompressed_folder_handle = tmp_dir; - copy_trx->_owns_uncompressed_folder = true; - - return copy_trx; + return copy; } -// TODO: verify that this function is actually necessary (there should not be preallocation zeros -// afaik) +/// Compute the used range in a preallocated TrxFile by finding the last non-zero length. +/// Returns (nb_streamlines_used, nb_vertices_used). template std::tuple TrxFile
::_get_real_len() { if (this->streamlines->_lengths.size() == 0) return std::make_tuple(0, 0); @@ -733,10 +633,9 @@ TrxFile
::_copy_fixed_arrays_from(TrxFile
*trx, int strs_start, int pts_s this->data_per_vertex[x.first]->_data.block( pts_start, 0, curr_pts_len, this->data_per_vertex[x.first]->_data.cols()) = trx->data_per_vertex[x.first]->_data.block(0, 0, curr_pts_len, trx->data_per_vertex[x.first]->_data.cols()); - new (&(this->data_per_vertex[x.first]->_offsets)) - Map>(trx->data_per_vertex[x.first]->_offsets.data(), - trx->data_per_vertex[x.first]->_offsets.rows(), - trx->data_per_vertex[x.first]->_offsets.cols()); + trx::detail::remap(this->data_per_vertex[x.first]->_offsets, trx->data_per_vertex[x.first]->_offsets.data(), + static_cast(trx->data_per_vertex[x.first]->_offsets.rows()), + static_cast(trx->data_per_vertex[x.first]->_offsets.cols())); this->data_per_vertex[x.first]->_lengths = trx->data_per_vertex[x.first]->_lengths; } @@ -792,7 +691,7 @@ template // resize is a no-op. void TrxFile
::resize(int nb_streamlines, int nb_vertices, bool delete_dpg) { if (!this->_copy_safe) { - throw std::invalid_argument("Cannot resize a sliced dataset."); + throw TrxArgumentError("Cannot resize a sliced dataset."); } std::tuple sp_end = this->_get_real_len(); @@ -831,13 +730,7 @@ void TrxFile
::resize(int nb_streamlines, int nb_vertices, bool delete_dpg) { if (this->groups.size() > 0) { std::string group_dir = tmp_dir + SEPARATOR + "groups" + SEPARATOR; - { - std::error_code ec; - trx::fs::create_directories(group_dir, ec); - if (ec) { - throw std::runtime_error("Could not create directory " + group_dir); - } - } + mkdir_or_throw(group_dir); for (auto const &x : this->groups) { std::string group_dtype = dtype_from_scalar(); @@ -863,10 +756,7 @@ void TrxFile
::resize(int nb_streamlines, int nb_vertices, bool delete_dpg) { trx->groups[x.first] = std::make_unique>(); trx->groups[x.first]->mmap = trx::_create_memmap(group_name, group_shape, "w+", group_dtype); - new (&(trx->groups[x.first]->_matrix)) - Map>(reinterpret_cast(trx->groups[x.first]->mmap.data()), - std::get<0>(group_shape), - std::get<1>(group_shape)); + trx::detail::remap(trx->groups[x.first]->_matrix, trx->groups[x.first]->mmap.data(), group_shape); // update values for (int i = 0; i < trx->groups[x.first]->_matrix.rows(); ++i) { @@ -883,25 +773,12 @@ void TrxFile
::resize(int nb_streamlines, int nb_vertices, bool delete_dpg) { } if (this->data_per_group.size() > 0) { - // really need to refactor all these mkdirs std::string dpg_dir = tmp_dir + SEPARATOR + "dpg" + SEPARATOR; - { - std::error_code ec; - trx::fs::create_directories(dpg_dir, ec); - if (ec) { - throw std::runtime_error("Could not create directory " + dpg_dir); - } - } + mkdir_or_throw(dpg_dir); for (auto const &x : this->data_per_group) { std::string dpg_subdir = dpg_dir + x.first; - { - std::error_code ec; - trx::fs::create_directories(dpg_subdir, ec); - if (ec) { - throw std::runtime_error("Could not create directory " + dpg_subdir); - } - } + mkdir_or_throw(dpg_subdir); if (trx->data_per_group.find(x.first) == trx->data_per_group.end()) { trx->data_per_group.emplace(x.first, std::map>>{}); @@ -922,10 +799,8 @@ void TrxFile
::resize(int nb_streamlines, int nb_vertices, bool delete_dpg) { } trx->data_per_group[x.first][y.first]->mmap = _create_memmap(dpg_filename, dpg_shape, "w+", dpg_dtype); - new (&(trx->data_per_group[x.first][y.first]->_matrix)) Map>( - reinterpret_cast(trx->data_per_group[x.first][y.first]->mmap.data()), - std::get<0>(dpg_shape), - std::get<1>(dpg_shape)); + trx::detail::remap(trx->data_per_group[x.first][y.first]->_matrix, + trx->data_per_group[x.first][y.first]->mmap.data(), dpg_shape); // update values for (int i = 0; i < trx->data_per_group[x.first][y.first]->_matrix.rows(); ++i) { @@ -942,13 +817,12 @@ void TrxFile
::resize(int nb_streamlines, int nb_vertices, bool delete_dpg) { template std::unique_ptr> TrxFile
::load_from_zip(const std::string &filename) { int errorp = 0; - zip_t *zf = open_zip_for_read(filename, errorp); - if (zf == nullptr) { - throw std::runtime_error("Could not open zip file: " + filename); + detail::ZipArchive zf(open_zip_for_read(filename, errorp)); + if (!zf) { + throw TrxIOError("Could not open zip file: " + filename); } - std::string temp_dir = extract_zip_to_directory(zf); - zip_close(zf); + std::string temp_dir = extract_zip_to_directory(zf.get()); auto trx = TrxFile
::load_from_directory(temp_dir); trx->_uncompressed_folder_handle = temp_dir; @@ -1003,14 +877,14 @@ template std::unique_ptr> TrxFile
::load_from_direc detail += "]"; } } - throw std::runtime_error(detail); + throw TrxIOError(detail); } std::string jstream((std::istreambuf_iterator(header_file)), std::istreambuf_iterator()); header_file.close(); std::string err; json header = json::parse(jstream, err); if (!err.empty()) { - throw std::runtime_error("Failed to parse header.json: " + err); + throw TrxIOError("Failed to parse header.json: " + err); } std::map> files_pointer_size; @@ -1025,7 +899,7 @@ template std::unique_ptr> TrxFile
::load_from_direc template std::unique_ptr> TrxFile
::load(const std::string &path) { trx::fs::path input(path); if (!trx::fs::exists(input)) { - throw std::runtime_error("Input path does not exist: " + path); + throw TrxIOError("Input path does not exist: " + path); } std::error_code ec; if (trx::fs::is_directory(input, ec) && !ec) { @@ -1078,15 +952,15 @@ template void TrxFile
::save(const std::string &filename, zip_u template void TrxFile
::normalize_for_save() { if (!this->streamlines) { - throw std::runtime_error("Cannot normalize TRX without streamline data"); + throw TrxFormatError("Cannot normalize TRX without streamline data"); } if (this->streamlines->_offsets.size() == 0) { - throw std::runtime_error("Cannot normalize TRX without offsets data"); + throw TrxFormatError("Cannot normalize TRX without offsets data"); } const size_t offsets_count = static_cast(this->streamlines->_offsets.size()); if (offsets_count < 1) { - throw std::runtime_error("Invalid offsets array"); + throw TrxFormatError("Invalid offsets array"); } const size_t total_streamlines = offsets_count - 1; const uint64_t data_rows = static_cast(this->streamlines->_data.rows()); @@ -1104,73 +978,69 @@ template void TrxFile
::normalize_for_save() { const uint64_t used_vertices = static_cast(this->streamlines->_offsets(static_cast(used_streamlines))); if (used_vertices > data_rows) { - throw std::runtime_error("TRX offsets exceed positions row count"); + throw TrxFormatError("TRX offsets exceed positions row count"); } if (used_vertices > static_cast(std::numeric_limits::max()) || used_streamlines > static_cast(std::numeric_limits::max())) { - throw std::runtime_error("TRX normalize_for_save exceeds supported int range"); + throw TrxFormatError("TRX normalize_for_save exceeds supported int range"); } if (used_streamlines < total_streamlines || used_vertices < data_rows) { this->resize(static_cast(used_streamlines), static_cast(used_vertices)); } - const size_t normalized_streamlines = static_cast(this->streamlines->_offsets.size()) - 1; + const size_t normalized_streamlines = this->num_streamlines(); for (size_t i = 0; i < normalized_streamlines; ++i) { const uint64_t curr = static_cast(this->streamlines->_offsets(static_cast(i))); const uint64_t next = static_cast(this->streamlines->_offsets(static_cast(i + 1))); if (next < curr) { - throw std::runtime_error("TRX offsets must be monotonically increasing"); + throw TrxFormatError("TRX offsets must be monotonically increasing"); } const uint64_t diff = next - curr; if (diff > static_cast(std::numeric_limits::max())) { - throw std::runtime_error("TRX streamline length exceeds uint32 range"); + throw TrxFormatError("TRX streamline length exceeds uint32 range"); } this->streamlines->_lengths(static_cast(i)) = static_cast(diff); } - const uint64_t sentinel = static_cast( - this->streamlines->_offsets(static_cast(this->streamlines->_offsets.size() - 1))); this->header = _json_set(this->header, "NB_STREAMLINES", static_cast(normalized_streamlines)); - this->header = _json_set(this->header, "NB_VERTICES", static_cast(sentinel)); + this->header = _json_set(this->header, "NB_VERTICES", static_cast(this->num_vertices())); } template void TrxFile
::save(const std::string &filename, const TrxSaveOptions &options) { std::string ext = get_ext(filename); if (ext.size() > 0 && ext != "zip" && ext != "trx") { - throw std::invalid_argument("Unsupported extension: " + ext); + throw TrxDTypeError("Unsupported extension: " + ext); } TrxFile
*save_trx = this; if (!save_trx->streamlines || save_trx->streamlines->_offsets.size() == 0) { - throw std::runtime_error("Cannot save TRX without offsets data"); + throw TrxFormatError("Cannot save TRX without offsets data"); } if (save_trx->header["NB_STREAMLINES"].is_number()) { const auto nb_streamlines = static_cast(save_trx->header["NB_STREAMLINES"].int_value()); if (save_trx->streamlines->_offsets.size() != static_cast(nb_streamlines + 1)) { - throw std::runtime_error("TRX offsets size does not match NB_STREAMLINES"); + throw TrxFormatError("TRX offsets size does not match NB_STREAMLINES"); } } if (save_trx->header["NB_VERTICES"].is_number()) { const auto nb_vertices = static_cast(save_trx->header["NB_VERTICES"].int_value()); - const auto last = - static_cast(save_trx->streamlines->_offsets(save_trx->streamlines->_offsets.size() - 1)); + const auto last = static_cast(save_trx->num_vertices()); if (last != nb_vertices) { - throw std::runtime_error("TRX offsets sentinel does not match NB_VERTICES"); + throw TrxFormatError("TRX offsets sentinel does not match NB_VERTICES"); } } for (Eigen::Index i = 1; i < save_trx->streamlines->_offsets.size(); ++i) { if (save_trx->streamlines->_offsets(i) < save_trx->streamlines->_offsets(i - 1)) { - throw std::runtime_error("TRX offsets must be monotonically increasing"); + throw TrxFormatError("TRX offsets must be monotonically increasing"); } } if (save_trx->streamlines->_data.size() > 0) { - const auto last = - static_cast(save_trx->streamlines->_offsets(save_trx->streamlines->_offsets.size() - 1)); + const auto last = static_cast(save_trx->num_vertices()); if (last != static_cast(save_trx->streamlines->_data.rows())) { - throw std::runtime_error("TRX positions row count does not match offsets sentinel"); + throw TrxFormatError("TRX positions row count does not match offsets sentinel"); } } std::string tmp_dir_name = save_trx->_uncompressed_folder_handle; @@ -1179,7 +1049,7 @@ template void TrxFile
::save(const std::string &filename, const const std::string header_path = tmp_dir_name + SEPARATOR + "header.json"; std::ofstream out_json(header_path, std::ios::out | std::ios::trunc); if (!out_json.is_open()) { - throw std::runtime_error("Failed to write header.json to: " + header_path); + throw TrxIOError("Failed to write header.json to: " + header_path); } out_json << save_trx->header.dump() << std::endl; out_json.close(); @@ -1221,44 +1091,37 @@ template void TrxFile
::save(const std::string &filename, const } int errorp; - zip_t *zf; - if ((zf = zip_open(filename.c_str(), ZIP_CREATE + ZIP_TRUNCATE, &errorp)) == nullptr) { - throw std::runtime_error("Could not open archive " + filename + ": " + strerror(errorp)); - } else { - zip_from_folder(zf, tmp_dir_name, tmp_dir_name, options.compression_standard, nullptr); - if (zip_close(zf) != 0) { - throw std::runtime_error("Unable to close archive " + filename + ": " + zip_strerror(zf)); - } + detail::ZipArchive zf(zip_open(filename.c_str(), ZIP_CREATE + ZIP_TRUNCATE, &errorp)); + if (!zf) { + throw TrxIOError("Could not open archive " + filename + ": " + strerror(errorp)); } + zip_from_folder(zf.get(), tmp_dir_name, tmp_dir_name, options.compression_standard, nullptr); + zf.commit(filename); } else { std::error_code ec; if (!trx::fs::exists(tmp_dir_name, ec) || !trx::fs::is_directory(tmp_dir_name, ec)) { - throw std::runtime_error("Temporary TRX directory does not exist: " + tmp_dir_name); + throw TrxIOError("Temporary TRX directory does not exist: " + tmp_dir_name); } if (trx::fs::exists(filename, ec) && trx::fs::is_directory(filename, ec)) { if (!options.overwrite_existing) { - throw std::runtime_error("Output directory already exists: " + filename); + throw TrxIOError("Output directory already exists: " + filename); } if (rm_dir(filename) != 0) { - throw std::runtime_error("Could not remove existing directory " + filename); + throw TrxIOError("Could not remove existing directory " + filename); } } trx::fs::path dest_path(filename); if (dest_path.has_parent_path()) { - std::error_code ec; - trx::fs::create_directories(dest_path.parent_path(), ec); - if (ec) { - throw std::runtime_error("Could not create output parent directory: " + dest_path.parent_path().string()); - } + mkdir_or_throw(dest_path.parent_path().string()); } copy_dir(tmp_dir_name, filename); ec.clear(); if (!trx::fs::exists(filename, ec) || !trx::fs::is_directory(filename, ec)) { - throw std::runtime_error("Failed to 
create output directory: " + filename); + throw TrxIOError("Failed to create output directory: " + filename); } const trx::fs::path header_path = dest_path / "header.json"; if (!trx::fs::exists(header_path)) { - throw std::runtime_error("Missing header.json in output directory: " + header_path.string()); + throw TrxFormatError("Missing header.json in output directory: " + header_path.string()); } } } @@ -1267,7 +1130,7 @@ template void TrxFile
::add_dps_from_text(const std::string &name, const std::string &dtype, const std::string &path) { std::ifstream input(path); if (!input.is_open()) { - throw std::runtime_error("Failed to open DPS text file: " + path); + throw TrxIOError("Failed to open DPS text file: " + path); } std::vector values; @@ -1276,7 +1139,7 @@ void TrxFile
::add_dps_from_text(const std::string &name, const std::string & values.push_back(value); } if (!input.eof() && input.fail()) { - throw std::runtime_error("Failed to parse DPS text file: " + path); + throw TrxFormatError("Failed to parse DPS text file: " + path); } add_dps_from_vector(name, dtype, values); @@ -1286,7 +1149,7 @@ template template void TrxFile
::add_dps_from_vector(const std::string &name, const std::string &dtype, const std::vector &values) { if (name.empty()) { - throw std::invalid_argument("DPS name cannot be empty"); + throw TrxArgumentError("DPS name cannot be empty"); } std::string dtype_norm = dtype; @@ -1295,14 +1158,14 @@ void TrxFile
::add_dps_from_vector(const std::string &name, const std::string }); if (!trx::detail::_is_dtype_valid(dtype_norm)) { - throw std::invalid_argument("Unsupported DPS dtype: " + dtype); + throw TrxDTypeError("Unsupported DPS dtype: " + dtype); } if (dtype_norm != "float16" && dtype_norm != "float32" && dtype_norm != "float64") { - throw std::invalid_argument("Unsupported DPS dtype: " + dtype); + throw TrxDTypeError("Unsupported DPS dtype: " + dtype); } if (this->_uncompressed_folder_handle.empty()) { - throw std::runtime_error("TRX file has no backing directory to store DPS data"); + throw TrxIOError("TRX file has no backing directory to store DPS data"); } size_t nb_streamlines = 0; @@ -1313,18 +1176,12 @@ void TrxFile
::add_dps_from_vector(const std::string &name, const std::string } if (values.size() != nb_streamlines) { - throw std::runtime_error("DPS values (" + std::to_string(values.size()) + ") do not match number of streamlines (" + + throw TrxFormatError("DPS values (" + std::to_string(values.size()) + ") do not match number of streamlines (" + std::to_string(nb_streamlines) + ")"); } std::string dps_dirname = this->_uncompressed_folder_handle + SEPARATOR + "dps" + SEPARATOR; - { - std::error_code ec; - trx::fs::create_directories(dps_dirname, ec); - if (ec) { - throw std::runtime_error("Could not create directory " + dps_dirname); - } - } + mkdir_or_throw(dps_dirname); std::string dps_filename = dps_dirname + name + "." + dtype_norm; { @@ -1346,27 +1203,9 @@ void TrxFile
::add_dps_from_vector(const std::string &name, const std::string auto matrix = std::make_unique>(); matrix->mmap = trx::_create_memmap(dps_filename, shape, "w+", dtype_norm); - if (dtype_norm == "float16") { - auto *data = reinterpret_cast(matrix->mmap.data()); - Map> mapped(data, rows, cols); - new (&(matrix->_matrix)) Map>(data, rows, cols); - for (int i = 0; i < rows; ++i) { - mapped(i, 0) = static_cast(values[static_cast(i)]); - } - } else if (dtype_norm == "float32") { - auto *data = reinterpret_cast(matrix->mmap.data()); - Map> mapped(data, rows, cols); - new (&(matrix->_matrix)) Map>(data, rows, cols); - for (int i = 0; i < rows; ++i) { - mapped(i, 0) = static_cast(values[static_cast(i)]); - } - } else { - auto *data = reinterpret_cast(matrix->mmap.data()); - Map> mapped(data, rows, cols); - new (&(matrix->_matrix)) Map>(data, rows, cols); - for (int i = 0; i < rows; ++i) { - mapped(i, 0) = static_cast(values[static_cast(i)]); - } + trx::detail::remap(matrix->_matrix, matrix->mmap.data(), rows, cols); + for (int i = 0; i < rows; ++i) { + matrix->_matrix(i, 0) = static_cast
(values[static_cast(i)]); } this->data_per_streamline[name] = std::move(matrix); @@ -1376,7 +1215,7 @@ template template void TrxFile
::add_dpv_from_vector(const std::string &name, const std::string &dtype, const std::vector &values) { if (name.empty()) { - throw std::invalid_argument("DPV name cannot be empty"); + throw TrxArgumentError("DPV name cannot be empty"); } std::string dtype_norm = dtype; @@ -1385,14 +1224,14 @@ void TrxFile
::add_dpv_from_vector(const std::string &name, const std::string }); if (!trx::detail::_is_dtype_valid(dtype_norm)) { - throw std::invalid_argument("Unsupported DPV dtype: " + dtype); + throw TrxDTypeError("Unsupported DPV dtype: " + dtype); } if (dtype_norm != "float16" && dtype_norm != "float32" && dtype_norm != "float64") { - throw std::invalid_argument("Unsupported DPV dtype: " + dtype); + throw TrxDTypeError("Unsupported DPV dtype: " + dtype); } if (this->_uncompressed_folder_handle.empty()) { - throw std::runtime_error("TRX file has no backing directory to store DPV data"); + throw TrxIOError("TRX file has no backing directory to store DPV data"); } size_t nb_vertices = 0; @@ -1403,18 +1242,12 @@ void TrxFile
::add_dpv_from_vector(const std::string &name, const std::string } if (values.size() != nb_vertices) { - throw std::runtime_error("DPV values (" + std::to_string(values.size()) + ") do not match number of vertices (" + + throw TrxFormatError("DPV values (" + std::to_string(values.size()) + ") do not match number of vertices (" + std::to_string(nb_vertices) + ")"); } std::string dpv_dirname = this->_uncompressed_folder_handle + SEPARATOR + "dpv" + SEPARATOR; - { - std::error_code ec; - trx::fs::create_directories(dpv_dirname, ec); - if (ec) { - throw std::runtime_error("Could not create directory " + dpv_dirname); - } - } + mkdir_or_throw(dpv_dirname); std::string dpv_filename = dpv_dirname + name + "." + dtype_norm; { @@ -1436,33 +1269,14 @@ void TrxFile
::add_dpv_from_vector(const std::string &name, const std::string auto seq = std::make_unique>(); seq->mmap_pos = trx::_create_memmap(dpv_filename, shape, "w+", dtype_norm); - if (dtype_norm == "float16") { - auto *data = reinterpret_cast(seq->mmap_pos.data()); - Map> mapped(data, rows, cols); - new (&(seq->_data)) Map>(data, rows, cols); - for (int i = 0; i < rows; ++i) { - mapped(i, 0) = static_cast(values[static_cast(i)]); - } - } else if (dtype_norm == "float32") { - auto *data = reinterpret_cast(seq->mmap_pos.data()); - Map> mapped(data, rows, cols); - new (&(seq->_data)) Map>(data, rows, cols); - for (int i = 0; i < rows; ++i) { - mapped(i, 0) = static_cast(values[static_cast(i)]); - } - } else { - auto *data = reinterpret_cast(seq->mmap_pos.data()); - Map> mapped(data, rows, cols); - new (&(seq->_data)) Map>(data, rows, cols); - for (int i = 0; i < rows; ++i) { - mapped(i, 0) = static_cast(values[static_cast(i)]); - } + trx::detail::remap(seq->_data, seq->mmap_pos.data(), rows, cols); + for (int i = 0; i < rows; ++i) { + seq->_data(i, 0) = static_cast
(values[static_cast(i)]); } if (this->streamlines && this->streamlines->_offsets.size() > 0) { - new (&(seq->_offsets)) Map>(this->streamlines->_offsets.data(), - int(this->streamlines->_offsets.rows()), - int(this->streamlines->_offsets.cols())); + trx::detail::remap(seq->_offsets, this->streamlines->_offsets.data(), int(this->streamlines->_offsets.rows()), + int(this->streamlines->_offsets.cols())); seq->_lengths = this->streamlines->_lengths; } @@ -1472,10 +1286,10 @@ void TrxFile
::add_dpv_from_vector(const std::string &name, const std::string template void TrxFile
::add_group_from_indices(const std::string &name, const std::vector &indices) { if (name.empty()) { - throw std::invalid_argument("Group name cannot be empty"); + throw TrxArgumentError("Group name cannot be empty"); } if (this->_uncompressed_folder_handle.empty()) { - throw std::runtime_error("TRX file has no backing directory to store groups"); + throw TrxIOError("TRX file has no backing directory to store groups"); } size_t nb_streamlines = 0; @@ -1487,18 +1301,12 @@ void TrxFile
::add_group_from_indices(const std::string &name, const std::vec for (const auto idx : indices) { if (idx >= nb_streamlines) { - throw std::runtime_error("Group index out of range: " + std::to_string(idx)); + throw TrxArgumentError("Group index out of range: " + std::to_string(idx)); } } std::string groups_dirname = this->_uncompressed_folder_handle + SEPARATOR + "groups" + SEPARATOR; - { - std::error_code ec; - trx::fs::create_directories(groups_dirname, ec); - if (ec) { - throw std::runtime_error("Could not create directory " + groups_dirname); - } - } + mkdir_or_throw(groups_dirname); std::string group_filename = groups_dirname + name + ".uint32"; { @@ -1519,8 +1327,7 @@ void TrxFile
::add_group_from_indices(const std::string &name, const std::vec auto group = std::make_unique>(); group->mmap = trx::_create_memmap(group_filename, shape, "w+", "uint32"); - new (&(group->_matrix)) Map>( - reinterpret_cast(group->mmap.data()), std::get<0>(shape), std::get<1>(shape)); + trx::detail::remap(group->_matrix, group->mmap.data(), shape); for (int i = 0; i < rows; ++i) { group->_matrix(i, 0) = indices[static_cast(i)]; } @@ -1543,7 +1350,7 @@ inline TrxStream::TrxStream(std::string positions_dtype) : positions_dtype_(std: return static_cast(std::tolower(c)); }); if (positions_dtype_ != "float32" && positions_dtype_ != "float16") { - throw std::invalid_argument("TrxStream only supports float16/float32 positions for now"); + throw TrxArgumentError("TrxStream only supports float16/float32 positions for now"); } tmp_dir_ = make_temp_dir("trx_proto"); positions_path_ = tmp_dir_ + SEPARATOR + "positions.tmp"; @@ -1552,39 +1359,36 @@ inline TrxStream::TrxStream(std::string positions_dtype) : positions_dtype_(std: inline TrxStream::~TrxStream() { cleanup_tmp(); } -inline void TrxStream::set_metadata_mode(MetadataMode mode) { +inline TrxStream &TrxStream::set_metadata_mode(MetadataMode mode) { if (finalized_) { - throw std::runtime_error("Cannot adjust metadata mode after finalize"); + throw TrxArgumentError("Cannot adjust metadata mode after finalize"); } metadata_mode_ = mode; + return *this; } -inline void TrxStream::set_metadata_buffer_max_bytes(std::size_t max_bytes) { +inline TrxStream &TrxStream::set_metadata_buffer_max_bytes(std::size_t max_bytes) { if (finalized_) { - throw std::runtime_error("Cannot adjust metadata buffer after finalize"); + throw TrxArgumentError("Cannot adjust metadata buffer after finalize"); } metadata_buffer_max_bytes_ = max_bytes; + return *this; } inline void TrxStream::ensure_positions_stream() { if (!positions_out_.is_open()) { positions_out_.open(positions_path_, std::ios::binary | std::ios::out | std::ios::trunc); if 
(!positions_out_.is_open()) { - throw std::runtime_error("Failed to open TrxStream temp positions file: " + positions_path_); + throw TrxIOError("Failed to open TrxStream temp positions file: " + positions_path_); } } } inline void TrxStream::ensure_metadata_dir(const std::string &subdir) { if (tmp_dir_.empty()) { - throw std::runtime_error("TrxStream temp directory not initialized"); - } - const std::string dir = tmp_dir_ + SEPARATOR + subdir + SEPARATOR; - std::error_code ec; - trx::fs::create_directories(dir, ec); - if (ec) { - throw std::runtime_error("Could not create directory " + dir); + throw TrxIOError("TrxStream temp directory not initialized"); } + mkdir_or_throw(tmp_dir_ + SEPARATOR + subdir + SEPARATOR); } inline void TrxStream::flush_positions_buffer() { @@ -1597,7 +1401,7 @@ inline void TrxStream::flush_positions_buffer() { positions_out_.write(reinterpret_cast(positions_buffer_half_.data()), static_cast(byte_count)); if (!positions_out_) { - throw std::runtime_error("Failed to write TrxStream positions buffer"); + throw TrxIOError("Failed to write TrxStream positions buffer"); } positions_buffer_half_.clear(); return; @@ -1611,7 +1415,7 @@ inline void TrxStream::flush_positions_buffer() { positions_out_.write(reinterpret_cast(positions_buffer_float_.data()), static_cast(byte_count)); if (!positions_out_) { - throw std::runtime_error("Failed to write TrxStream positions buffer"); + throw TrxIOError("Failed to write TrxStream positions buffer"); } positions_buffer_float_.clear(); } @@ -1630,7 +1434,7 @@ inline void TrxStream::cleanup_tmp() { inline void TrxStream::push_streamline(const float *xyz, size_t point_count) { if (finalized_) { - throw std::runtime_error("TrxStream already finalized"); + throw TrxArgumentError("TrxStream already finalized"); } if (point_count == 0) { lengths_.push_back(0); @@ -1647,13 +1451,13 @@ inline void TrxStream::push_streamline(const float *xyz, size_t point_count) { const size_t byte_count = tmp.size() * sizeof(half); 
positions_out_.write(reinterpret_cast(tmp.data()), static_cast(byte_count)); if (!positions_out_) { - throw std::runtime_error("Failed to write TrxStream positions"); + throw TrxIOError("Failed to write TrxStream positions"); } } else { const size_t byte_count = point_count * 3 * sizeof(float); positions_out_.write(reinterpret_cast(xyz), static_cast(byte_count)); if (!positions_out_) { - throw std::runtime_error("Failed to write TrxStream positions"); + throw TrxIOError("Failed to write TrxStream positions"); } } } else { @@ -1679,7 +1483,7 @@ inline void TrxStream::push_streamline(const float *xyz, size_t point_count) { inline void TrxStream::push_streamline(const std::vector &xyz_flat) { if (xyz_flat.size() % 3 != 0) { - throw std::invalid_argument("TrxStream streamline buffer must be a multiple of 3"); + throw TrxArgumentError("TrxStream streamline buffer must be a multiple of 3"); } push_streamline(xyz_flat.data(), xyz_flat.size() / 3); } @@ -1699,7 +1503,7 @@ inline void TrxStream::push_streamline(const std::vector> & push_streamline(xyz_flat); } -inline void TrxStream::set_voxel_to_rasmm(const Eigen::Matrix4f &affine) { +inline TrxStream &TrxStream::set_voxel_to_rasmm(const Eigen::Matrix4f &affine) { std::vector> matrix(4, std::vector(4, 0.0f)); for (int i = 0; i < 4; ++i) { for (int j = 0; j < 4; ++j) { @@ -1707,34 +1511,36 @@ inline void TrxStream::set_voxel_to_rasmm(const Eigen::Matrix4f &affine) { } } header = _json_set(header, "VOXEL_TO_RASMM", matrix); + return *this; } -inline void TrxStream::set_dimensions(const std::array &dims) { +inline TrxStream &TrxStream::set_dimensions(const std::array &dims) { header = _json_set(header, "DIMENSIONS", std::vector{dims[0], dims[1], dims[2]}); + return *this; } template inline void TrxStream::push_dps_from_vector(const std::string &name, const std::string &dtype, const std::vector &values) { if (name.empty()) { - throw std::invalid_argument("DPS name cannot be empty"); + throw TrxArgumentError("DPS name cannot be 
empty"); } std::string dtype_norm = dtype; std::transform(dtype_norm.begin(), dtype_norm.end(), dtype_norm.begin(), [](unsigned char c) { return static_cast(std::tolower(c)); }); if (!trx::detail::_is_dtype_valid(dtype_norm)) { - throw std::invalid_argument("Unsupported DPS dtype: " + dtype); + throw TrxDTypeError("Unsupported DPS dtype: " + dtype); } if (dtype_norm != "float16" && dtype_norm != "float32" && dtype_norm != "float64") { - throw std::invalid_argument("Unsupported DPS dtype: " + dtype); + throw TrxDTypeError("Unsupported DPS dtype: " + dtype); } if (metadata_mode_ == MetadataMode::OnDisk) { ensure_metadata_dir("dps"); const std::string filename = tmp_dir_ + SEPARATOR + "dps" + SEPARATOR + name + "." + dtype_norm; std::ofstream out(filename, std::ios::binary | std::ios::out | std::ios::trunc); if (!out.is_open()) { - throw std::runtime_error("Failed to open DPS file: " + filename); + throw TrxIOError("Failed to open DPS file: " + filename); } if (dtype_norm == "float16") { const size_t chunk_elems = std::max(1, metadata_buffer_max_bytes_ / sizeof(half)); @@ -1796,24 +1602,24 @@ template inline void TrxStream::push_dpv_from_vector(const std::string &name, const std::string &dtype, const std::vector &values) { if (name.empty()) { - throw std::invalid_argument("DPV name cannot be empty"); + throw TrxArgumentError("DPV name cannot be empty"); } std::string dtype_norm = dtype; std::transform(dtype_norm.begin(), dtype_norm.end(), dtype_norm.begin(), [](unsigned char c) { return static_cast(std::tolower(c)); }); if (!trx::detail::_is_dtype_valid(dtype_norm)) { - throw std::invalid_argument("Unsupported DPV dtype: " + dtype); + throw TrxDTypeError("Unsupported DPV dtype: " + dtype); } if (dtype_norm != "float16" && dtype_norm != "float32" && dtype_norm != "float64") { - throw std::invalid_argument("Unsupported DPV dtype: " + dtype); + throw TrxDTypeError("Unsupported DPV dtype: " + dtype); } if (metadata_mode_ == MetadataMode::OnDisk) { 
ensure_metadata_dir("dpv"); const std::string filename = tmp_dir_ + SEPARATOR + "dpv" + SEPARATOR + name + "." + dtype_norm; std::ofstream out(filename, std::ios::binary | std::ios::out | std::ios::trunc); if (!out.is_open()) { - throw std::runtime_error("Failed to open DPV file: " + filename); + throw TrxIOError("Failed to open DPV file: " + filename); } if (dtype_norm == "float16") { const size_t chunk_elems = std::max(1, metadata_buffer_max_bytes_ / sizeof(half)); @@ -1873,7 +1679,7 @@ TrxStream::push_dpv_from_vector(const std::string &name, const std::string &dtyp inline void TrxStream::set_positions_buffer_max_bytes(std::size_t max_bytes) { if (finalized_) { - throw std::runtime_error("Cannot adjust buffer after finalize"); + throw TrxArgumentError("Cannot adjust buffer after finalize"); } if (max_bytes == 0) { positions_buffer_max_entries_ = 0; @@ -1893,14 +1699,14 @@ inline void TrxStream::set_positions_buffer_max_bytes(std::size_t max_bytes) { inline void TrxStream::push_group_from_indices(const std::string &name, const std::vector &indices) { if (name.empty()) { - throw std::invalid_argument("Group name cannot be empty"); + throw TrxArgumentError("Group name cannot be empty"); } if (metadata_mode_ == MetadataMode::OnDisk) { ensure_metadata_dir("groups"); const std::string filename = tmp_dir_ + SEPARATOR + "groups" + SEPARATOR + name + ".uint32"; std::ofstream out(filename, std::ios::binary | std::ios::out | std::ios::trunc); if (!out.is_open()) { - throw std::runtime_error("Failed to open group file: " + filename); + throw TrxIOError("Failed to open group file: " + filename); } const size_t chunk_elems = std::max(1, metadata_buffer_max_bytes_ / sizeof(uint32_t)); size_t offset = 0; @@ -1919,7 +1725,7 @@ inline void TrxStream::push_group_from_indices(const std::string &name, const st template void TrxStream::finalize(const std::string &filename, zip_uint32_t compression_standard) { if (finalized_) { - throw std::runtime_error("TrxStream already finalized"); 
+ throw TrxArgumentError("TrxStream already finalized"); } finalized_ = true; @@ -1951,14 +1757,14 @@ template void TrxStream::finalize(const std::string &filename, zip std::ifstream in(positions_path_, std::ios::binary); if (!in.is_open()) { - throw std::runtime_error("Failed to open TrxStream temp positions file for read: " + positions_path_); + throw TrxIOError("Failed to open TrxStream temp positions file for read: " + positions_path_); } for (size_t i = 0; i < nb_vertices; ++i) { if (positions_dtype_ == "float16") { half xyz[3]; in.read(reinterpret_cast(xyz), sizeof(xyz)); if (!in) { - throw std::runtime_error("Failed to read TrxStream positions"); + throw TrxIOError("Failed to read TrxStream positions"); } positions(static_cast(i), 0) = static_cast
(xyz[0]); positions(static_cast(i), 1) = static_cast
(xyz[1]); @@ -1967,7 +1773,7 @@ template void TrxStream::finalize(const std::string &filename, zip float xyz[3]; in.read(reinterpret_cast(xyz), sizeof(xyz)); if (!in) { - throw std::runtime_error("Failed to read TrxStream positions"); + throw TrxIOError("Failed to read TrxStream positions"); } positions(static_cast(i), 0) = static_cast
(xyz[0]); positions(static_cast(i), 1) = static_cast
(xyz[1]); @@ -1996,7 +1802,7 @@ template void TrxStream::finalize(const std::string &filename, zip std::error_code copy_ec; trx::fs::copy_file(meta.absolute_path, dest, trx::fs::copy_options::overwrite_existing, copy_ec); if (copy_ec) { - throw std::runtime_error("Failed to copy metadata file: " + meta.absolute_path + " -> " + dest); + throw TrxIOError("Failed to copy metadata file: " + meta.absolute_path + " -> " + dest); } } } @@ -2027,7 +1833,7 @@ inline void TrxStream::finalize(const std::string &filename, inline void TrxStream::finalize(const std::string &filename, const TrxSaveOptions &options) { if (options.mode == TrxSaveMode::Directory) { if (finalized_) { - throw std::runtime_error("TrxStream already finalized"); + throw TrxArgumentError("TrxStream already finalized"); } if (options.overwrite_existing) { finalize_directory(filename); @@ -2048,7 +1854,7 @@ inline void TrxStream::finalize(const std::string &filename, const TrxSaveOption inline void TrxStream::finalize_directory_impl(const std::string &directory, bool remove_existing) { if (finalized_) { - throw std::runtime_error("TrxStream already finalized"); + throw TrxArgumentError("TrxStream already finalized"); } finalized_ = true; @@ -2069,10 +1875,7 @@ inline void TrxStream::finalize_directory_impl(const std::string &directory, boo // Create directory if it doesn't exist if (!trx::fs::exists(directory, ec)) { - trx::fs::create_directories(directory, ec); - if (ec) { - throw std::runtime_error("Failed to create output directory: " + directory); - } + mkdir_or_throw(directory); } ec.clear(); @@ -2082,7 +1885,7 @@ inline void TrxStream::finalize_directory_impl(const std::string &directory, boo const std::string header_path = directory + SEPARATOR + "header.json"; std::ofstream out_header(header_path, std::ios::out | std::ios::trunc); if (!out_header.is_open()) { - throw std::runtime_error("Failed to write header.json to: " + header_path); + throw TrxIOError("Failed to write header.json to: " + 
header_path); } out_header << header_out.dump() << std::endl; out_header.close(); @@ -2094,14 +1897,14 @@ inline void TrxStream::finalize_directory_impl(const std::string &directory, boo ec.clear(); trx::fs::copy_file(positions_path_, positions_dst, trx::fs::copy_options::overwrite_existing, ec); if (ec) { - throw std::runtime_error("Failed to copy positions file to: " + positions_dst); + throw TrxIOError("Failed to copy positions file to: " + positions_dst); } } const std::string offsets_dst = directory + SEPARATOR + "offsets.uint64"; std::ofstream offsets_out(offsets_dst, std::ios::binary | std::ios::out | std::ios::trunc); if (!offsets_out.is_open()) { - throw std::runtime_error("Failed to open offsets file for write: " + offsets_dst); + throw TrxIOError("Failed to open offsets file for write: " + offsets_dst); } uint64_t offset = 0; offsets_out.write(reinterpret_cast(&offset), sizeof(offset)); @@ -2115,7 +1918,7 @@ inline void TrxStream::finalize_directory_impl(const std::string &directory, boo auto write_field_values = [&](const std::string &path, const FieldValues &values) { std::ofstream out(path, std::ios::binary | std::ios::out | std::ios::trunc); if (!out.is_open()) { - throw std::runtime_error("Failed to open metadata file: " + path); + throw TrxIOError("Failed to open metadata file: " + path); } const size_t count = values.values.size(); if (values.dtype == "float16") { @@ -2161,7 +1964,7 @@ inline void TrxStream::finalize_directory_impl(const std::string &directory, boo idx += n; } } else { - throw std::runtime_error("Unsupported metadata dtype: " + values.dtype); + throw TrxDTypeError("Unsupported metadata dtype: " + values.dtype); } out.close(); }; @@ -2177,31 +1980,31 @@ inline void TrxStream::finalize_directory_impl(const std::string &directory, boo std::error_code copy_ec; trx::fs::copy_file(meta.absolute_path, dest, trx::fs::copy_options::overwrite_existing, copy_ec); if (copy_ec) { - throw std::runtime_error("Failed to copy metadata file: " + 
meta.absolute_path + " -> " + dest); + throw TrxIOError("Failed to copy metadata file: " + meta.absolute_path + " -> " + dest); } } } else { if (!dps_.empty()) { - trx::fs::create_directories(directory + SEPARATOR + "dps", ec); + mkdir_or_throw(directory + SEPARATOR + "dps"); for (const auto &kv : dps_) { const std::string path = directory + SEPARATOR + "dps" + SEPARATOR + kv.first + "." + kv.second.dtype; write_field_values(path, kv.second); } } if (!dpv_.empty()) { - trx::fs::create_directories(directory + SEPARATOR + "dpv", ec); + mkdir_or_throw(directory + SEPARATOR + "dpv"); for (const auto &kv : dpv_) { const std::string path = directory + SEPARATOR + "dpv" + SEPARATOR + kv.first + "." + kv.second.dtype; write_field_values(path, kv.second); } } if (!groups_.empty()) { - trx::fs::create_directories(directory + SEPARATOR + "groups", ec); + mkdir_or_throw(directory + SEPARATOR + "groups"); for (const auto &kv : groups_) { const std::string path = directory + SEPARATOR + "groups" + SEPARATOR + kv.first + ".uint32"; std::ofstream out(path, std::ios::binary | std::ios::out | std::ios::trunc); if (!out.is_open()) { - throw std::runtime_error("Failed to open group file: " + path); + throw TrxIOError("Failed to open group file: " + path); } if (!kv.second.empty()) { out.write(reinterpret_cast(kv.second.data()), @@ -2226,7 +2029,7 @@ inline void TrxStream::finalize_directory_persistent(const std::string &director template void TrxFile
::add_dpv_from_tsf(const std::string &name, const std::string &dtype, const std::string &path) { if (name.empty()) { - throw std::invalid_argument("DPV name cannot be empty"); + throw TrxArgumentError("DPV name cannot be empty"); } std::string dtype_norm = dtype; @@ -2235,17 +2038,17 @@ void TrxFile
::add_dpv_from_tsf(const std::string &name, const std::string &d }); if (!trx::detail::_is_dtype_valid(dtype_norm)) { - throw std::invalid_argument("Unsupported DPV dtype: " + dtype); + throw TrxDTypeError("Unsupported DPV dtype: " + dtype); } if (dtype_norm != "float16" && dtype_norm != "float32" && dtype_norm != "float64") { - throw std::invalid_argument("Unsupported DPV dtype for TSF input: " + dtype); + throw TrxDTypeError("Unsupported DPV dtype for TSF input: " + dtype); } if (!this->streamlines) { - throw std::runtime_error("TRX file has no streamlines to attach DPV data"); + throw TrxFormatError("TRX file has no streamlines to attach DPV data"); } if (this->_uncompressed_folder_handle.empty()) { - throw std::runtime_error("TRX file has no backing directory to store DPV data"); + throw TrxIOError("TRX file has no backing directory to store DPV data"); } const auto &lengths = this->streamlines->_lengths; @@ -2254,7 +2057,7 @@ void TrxFile
::add_dpv_from_tsf(const std::string &name, const std::string &d std::ifstream input(path); if (!input.is_open()) { - throw std::runtime_error("Failed to open TSF file: " + path); + throw TrxIOError("Failed to open TSF file: " + path); } auto trim = [](std::string note) { @@ -2299,14 +2102,14 @@ void TrxFile
::add_dpv_from_tsf(const std::string &name, const std::string &d } } if (!found_end) { - throw std::runtime_error("Failed to parse TSF header: missing END"); + throw TrxFormatError("Failed to parse TSF header: missing END"); } } else { input.clear(); input.seekg(start_pos); } } else { - throw std::runtime_error("Failed to parse TSF file: " + path); + throw TrxFormatError("Failed to parse TSF file: " + path); } std::vector values; @@ -2317,7 +2120,7 @@ void TrxFile
::add_dpv_from_tsf(const std::string &name, const std::string &d if (binary_mode) { if (datatype != "Float32LE" && datatype != "Float32BE" && datatype != "Float64LE" && datatype != "Float64BE") { - throw std::runtime_error("Unsupported TSF datatype: " + datatype); + throw TrxDTypeError("Unsupported TSF datatype: " + datatype); } auto is_little_endian = []() { @@ -2365,7 +2168,7 @@ void TrxFile
::add_dpv_from_tsf(const std::string &name, const std::string &d } if (std::isnan(value)) { if (current_vertices != expected_vertices) { - throw std::runtime_error("TSF streamline length does not match TRX streamlines"); + throw TrxFormatError("TSF streamline length does not match TRX streamlines"); } if (streamline_index + 1 < nb_streamlines) { ++streamline_index; @@ -2386,7 +2189,7 @@ void TrxFile
::add_dpv_from_tsf(const std::string &name, const std::string &d }); if (token_norm.rfind("nan", 0) == 0) { if (current_vertices != expected_vertices) { - throw std::runtime_error("TSF streamline length does not match TRX streamlines"); + throw TrxFormatError("TSF streamline length does not match TRX streamlines"); } if (streamline_index + 1 < nb_streamlines) { ++streamline_index; @@ -2403,17 +2206,17 @@ void TrxFile
::add_dpv_from_tsf(const std::string &name, const std::string &d size_t idx = 0; value = std::stod(token, &idx); if (idx != token.size()) { - throw std::invalid_argument("invalid token"); + throw TrxArgumentError("invalid token"); } } catch (const std::exception &) { - throw std::runtime_error("Failed to parse TSF file: " + path); + throw TrxFormatError("Failed to parse TSF file: " + path); } if (std::isinf(value)) { break; } if (std::isnan(value)) { if (current_vertices != expected_vertices) { - throw std::runtime_error("TSF streamline length does not match TRX streamlines"); + throw TrxFormatError("TSF streamline length does not match TRX streamlines"); } if (streamline_index + 1 < nb_streamlines) { ++streamline_index; @@ -2426,27 +2229,21 @@ void TrxFile
::add_dpv_from_tsf(const std::string &name, const std::string &d ++current_vertices; } if (!input.eof() && input.fail()) { - throw std::runtime_error("Failed to parse TSF file: " + path); + throw TrxFormatError("Failed to parse TSF file: " + path); } } if (nb_streamlines > 0) { if (streamline_index != nb_streamlines - 1 || current_vertices != expected_vertices) { - throw std::runtime_error("TSF streamline count does not match TRX streamlines"); + throw TrxFormatError("TSF streamline count does not match TRX streamlines"); } } if (values.size() != nb_vertices) { - throw std::runtime_error("TSF values (" + std::to_string(values.size()) + ") do not match number of vertices (" + + throw TrxFormatError("TSF values (" + std::to_string(values.size()) + ") do not match number of vertices (" + std::to_string(nb_vertices) + ")"); } std::string dpv_dirname = this->_uncompressed_folder_handle + SEPARATOR + "dpv" + SEPARATOR; - { - std::error_code ec; - trx::fs::create_directories(dpv_dirname, ec); - if (ec) { - throw std::runtime_error("Could not create directory " + dpv_dirname); - } - } + mkdir_or_throw(dpv_dirname); std::string dpv_filename = dpv_dirname + name + "." + dtype_norm; { @@ -2468,32 +2265,14 @@ void TrxFile
::add_dpv_from_tsf(const std::string &name, const std::string &d auto seq = std::make_unique>(); seq->mmap_pos = trx::_create_memmap(dpv_filename, shape, "w+", dtype_norm); - if (dtype_norm == "float16") { - auto *data = reinterpret_cast(seq->mmap_pos.data()); - Map> mapped(data, rows, cols); - new (&(seq->_data)) Map>(data, rows, cols); - for (int i = 0; i < rows; ++i) { - mapped(i, 0) = static_cast(values[static_cast(i)]); - } - } else if (dtype_norm == "float32") { - auto *data = reinterpret_cast(seq->mmap_pos.data()); - Map> mapped(data, rows, cols); - new (&(seq->_data)) Map>(data, rows, cols); - for (int i = 0; i < rows; ++i) { - mapped(i, 0) = static_cast(values[static_cast(i)]); - } - } else { - auto *data = reinterpret_cast(seq->mmap_pos.data()); - Map> mapped(data, rows, cols); - new (&(seq->_data)) Map>(data, rows, cols); - for (int i = 0; i < rows; ++i) { - mapped(i, 0) = static_cast(values[static_cast(i)]); - } + trx::detail::remap(seq->_data, seq->mmap_pos.data(), rows, cols); + for (int i = 0; i < rows; ++i) { + seq->_data(i, 0) = static_cast
(values[static_cast(i)]); } - new (&(seq->_offsets)) Map>(this->streamlines->_offsets.data(), - static_cast(this->streamlines->_offsets.rows()), - static_cast(this->streamlines->_offsets.cols())); + trx::detail::remap(seq->_offsets, this->streamlines->_offsets.data(), + static_cast(this->streamlines->_offsets.rows()), + static_cast(this->streamlines->_offsets.cols())); seq->_lengths = this->streamlines->_lengths; this->data_per_vertex[name] = std::move(seq); @@ -2505,10 +2284,10 @@ void TrxFile
::export_dpv_to_tsf(const std::string &name, const std::string ×tamp, const std::string &dtype) const { if (name.empty()) { - throw std::invalid_argument("DPV name cannot be empty"); + throw TrxArgumentError("DPV name cannot be empty"); } if (timestamp.empty()) { - throw std::invalid_argument("TSF timestamp cannot be empty"); + throw TrxArgumentError("TSF timestamp cannot be empty"); } std::string dtype_norm = dtype; @@ -2517,34 +2296,34 @@ void TrxFile
::export_dpv_to_tsf(const std::string &name, }); if (!trx::detail::_is_dtype_valid(dtype_norm)) { - throw std::invalid_argument("Unsupported TSF dtype: " + dtype); + throw TrxDTypeError("Unsupported TSF dtype: " + dtype); } if (dtype_norm != "float32" && dtype_norm != "float64") { - throw std::invalid_argument("Unsupported TSF dtype for output: " + dtype); + throw TrxDTypeError("Unsupported TSF dtype for output: " + dtype); } if (!this->streamlines) { - throw std::runtime_error("TRX file has no streamlines to export DPV data"); + throw TrxFormatError("TRX file has no streamlines to export DPV data"); } const auto dpv_it = this->data_per_vertex.find(name); if (dpv_it == this->data_per_vertex.end()) { - throw std::runtime_error("DPV entry not found: " + name); + throw TrxFormatError("DPV entry not found: " + name); } const auto *seq = dpv_it->second.get(); if (!seq) { - throw std::runtime_error("DPV entry is null: " + name); + throw TrxFormatError("DPV entry is null: " + name); } if (seq->_data.cols() != 1) { - throw std::runtime_error("DPV must be 1D to export as TSF: " + name); + throw TrxFormatError("DPV must be 1D to export as TSF: " + name); } const auto &lengths = this->streamlines->_lengths; const size_t nb_streamlines = static_cast(lengths.size()); const size_t nb_vertices = static_cast(seq->_data.rows()); if (nb_vertices != static_cast(this->streamlines->_data.rows())) { - throw std::runtime_error("DPV vertex count does not match streamlines data"); + throw TrxFormatError("DPV vertex count does not match streamlines data"); } const auto is_little_endian = []() { @@ -2582,7 +2361,7 @@ void TrxFile
::export_dpv_to_tsf(const std::string &name, std::ofstream out(path, std::ios::binary | std::ios::trunc); if (!out.is_open()) { - throw std::runtime_error("Failed to open TSF file for writing: " + path); + throw TrxIOError("Failed to open TSF file for writing: " + path); } out.write(header.data(), static_cast(header.size())); const size_t pad = (4 - (header.size() % 4)) % 4; @@ -2606,13 +2385,13 @@ void TrxFile
::export_dpv_to_tsf(const std::string &name, for (size_t s = 0; s < nb_streamlines; ++s) { const uint32_t len = lengths(static_cast(s)); if (offset > total_vertices) { - throw std::runtime_error("DPV length metadata exceeds vertex count"); + throw TrxFormatError("DPV length metadata exceeds vertex count"); } if (len > std::numeric_limits::max() - offset) { - throw std::runtime_error("DPV length metadata exceeds vertex count"); + throw TrxFormatError("DPV length metadata exceeds vertex count"); } if (offset + static_cast(len) > total_vertices) { - throw std::runtime_error("DPV length metadata exceeds vertex count"); + throw TrxFormatError("DPV length metadata exceeds vertex count"); } offset += static_cast(len); } @@ -2622,7 +2401,7 @@ void TrxFile
::export_dpv_to_tsf(const std::string &name, for (uint32_t i = 0; i < len; ++i) { const size_t idx = offset + static_cast(i); if (idx > static_cast(std::numeric_limits::max())) { - throw std::runtime_error("DPV length metadata exceeds vertex count"); + throw TrxFormatError("DPV length metadata exceeds vertex count"); } write_value(static_cast(seq->_data(static_cast(idx), 0))); } @@ -2634,7 +2413,7 @@ void TrxFile
::export_dpv_to_tsf(const std::string &name, write_value(std::numeric_limits::infinity()); if (!out.good()) { - throw std::runtime_error("Failed to write TSF file: " + path); + throw TrxIOError("Failed to write TSF file: " + path); } } @@ -2721,24 +2500,16 @@ std::unique_ptr> TrxFile
::query_aabb( const std::array &min_corner, const std::array &max_corner, const std::vector> *precomputed_aabbs, - bool build_cache_for_result) const { + bool build_cache_for_result, + size_t max_streamlines, + uint32_t rng_seed) const { if (!this->streamlines) { - auto empty = std::make_unique>(); - empty->header = _json_set(this->header, "NB_VERTICES", 0); - empty->header = _json_set(empty->header, "NB_STREAMLINES", 0); - return empty; + return this->make_empty_like(); } - size_t nb_streamlines = 0; - if (this->streamlines->_offsets.size() > 0) { - nb_streamlines = static_cast(this->streamlines->_offsets.size() - 1); - } else if (this->streamlines->_lengths.size() > 0) { - nb_streamlines = static_cast(this->streamlines->_lengths.size()); - } else { - auto empty = std::make_unique>(); - empty->header = _json_set(this->header, "NB_VERTICES", 0); - empty->header = _json_set(empty->header, "NB_STREAMLINES", 0); - return empty; + const size_t nb_streamlines = this->num_streamlines(); + if (nb_streamlines == 0) { + return this->make_empty_like(); } std::vector> aabbs_local; @@ -2746,7 +2517,7 @@ std::unique_ptr> TrxFile
::query_aabb( ? *precomputed_aabbs : (!this->aabb_cache_.empty() ? this->aabb_cache_ : (aabbs_local = this->build_streamline_aabbs())); if (aabbs.size() != nb_streamlines) { - throw std::invalid_argument("AABB size does not match streamlines count"); + throw TrxArgumentError("AABB size does not match streamlines count"); } const float min_x = min_corner[0]; @@ -2775,6 +2546,14 @@ std::unique_ptr> TrxFile
::query_aabb( } } + if (max_streamlines > 0 && selected.size() > max_streamlines) { + std::mt19937 rng(rng_seed); + std::shuffle(selected.begin(), selected.end(), rng); + selected.resize(max_streamlines); + // Re-sort by index for sequential memory access in subset_streamlines. + std::sort(selected.begin(), selected.end()); + } + return this->subset_streamlines(selected, build_cache_for_result); } @@ -2804,9 +2583,9 @@ const ArraySequence
*TrxFile
::get_dpv(const std::string &name) const { template std::vector> TrxFile
::get_streamline(size_t streamline_index) const { if (!this->streamlines || this->streamlines->_offsets.size() == 0) { - throw std::runtime_error("TRX streamlines are not available"); + throw TrxFormatError("TRX streamlines are not available"); } - const size_t n_streamlines = static_cast(this->streamlines->_offsets.size() - 1); + const size_t n_streamlines = this->num_streamlines(); if (streamline_index >= n_streamlines) { throw std::out_of_range("Streamline index out of range"); } @@ -2833,7 +2612,7 @@ void TrxFile
::for_each_streamline(Fn &&fn) const { if (!this->streamlines || this->streamlines->_offsets.size() == 0) { return; } - const size_t n_streamlines = static_cast(this->streamlines->_offsets.size() - 1); + const size_t n_streamlines = this->num_streamlines(); for (size_t i = 0; i < n_streamlines; ++i) { const uint64_t start = static_cast(this->streamlines->_offsets(static_cast(i), 0)); const uint64_t end = static_cast(this->streamlines->_offsets(static_cast(i + 1), 0)); @@ -2850,56 +2629,44 @@ void TrxFile
::add_dpg_from_vector(const std::string &group, int rows, int cols) { if (group.empty()) { - throw std::invalid_argument("DPG group cannot be empty"); + throw TrxArgumentError("DPG group cannot be empty"); } if (name.empty()) { - throw std::invalid_argument("DPG name cannot be empty"); + throw TrxArgumentError("DPG name cannot be empty"); } std::string dtype_norm = dtype; std::transform(dtype_norm.begin(), dtype_norm.end(), dtype_norm.begin(), [](unsigned char c) { return static_cast(std::tolower(c)); }); if (!trx::detail::_is_dtype_valid(dtype_norm)) { - throw std::invalid_argument("Unsupported DPG dtype: " + dtype); + throw TrxDTypeError("Unsupported DPG dtype: " + dtype); } if (dtype_norm != "float16" && dtype_norm != "float32" && dtype_norm != "float64") { - throw std::invalid_argument("Unsupported DPG dtype: " + dtype); + throw TrxDTypeError("Unsupported DPG dtype: " + dtype); } if (this->_uncompressed_folder_handle.empty()) { - throw std::runtime_error("TRX file has no backing directory to store DPG data"); + throw TrxIOError("TRX file has no backing directory to store DPG data"); } if (rows <= 0) { - throw std::invalid_argument("DPG rows must be positive"); + throw TrxArgumentError("DPG rows must be positive"); } if (cols < 0) { if (values.size() % static_cast(rows) != 0) { - throw std::invalid_argument("DPG values size does not match rows"); + throw TrxArgumentError("DPG values size does not match rows"); } cols = static_cast(values.size() / static_cast(rows)); } if (cols <= 0) { - throw std::invalid_argument("DPG cols must be positive"); + throw TrxArgumentError("DPG cols must be positive"); } if (static_cast(rows) * static_cast(cols) != values.size()) { - throw std::invalid_argument("DPG values size does not match rows*cols"); + throw TrxArgumentError("DPG values size does not match rows*cols"); } std::string dpg_dir = this->_uncompressed_folder_handle + SEPARATOR + "dpg" + SEPARATOR; - { - std::error_code ec; - trx::fs::create_directories(dpg_dir, ec); - 
if (ec) { - throw std::runtime_error("Could not create directory " + dpg_dir); - } - } + mkdir_or_throw(dpg_dir); std::string dpg_subdir = dpg_dir + group; - { - std::error_code ec; - trx::fs::create_directories(dpg_subdir, ec); - if (ec) { - throw std::runtime_error("Could not create directory " + dpg_subdir); - } - } + mkdir_or_throw(dpg_subdir); std::string dpg_filename = dpg_subdir + SEPARATOR + name + "." + dtype_norm; { @@ -2916,27 +2683,9 @@ void TrxFile
::add_dpg_from_vector(const std::string &group, group_map[name] = std::make_unique>(); group_map[name]->mmap = _create_memmap(dpg_filename, shape, "w+", dtype_norm); - if (dtype_norm == "float16") { - auto *data = reinterpret_cast(group_map[name]->mmap.data()); - Map> mapped(data, rows, cols); - new (&(group_map[name]->_matrix)) Map>(data, rows, cols); - for (int i = 0; i < rows * cols; ++i) { - data[i] = static_cast(values[static_cast(i)]); - } - } else if (dtype_norm == "float32") { - auto *data = reinterpret_cast(group_map[name]->mmap.data()); - Map> mapped(data, rows, cols); - new (&(group_map[name]->_matrix)) Map>(data, rows, cols); - for (int i = 0; i < rows * cols; ++i) { - data[i] = static_cast(values[static_cast(i)]); - } - } else { - auto *data = reinterpret_cast(group_map[name]->mmap.data()); - Map> mapped(data, rows, cols); - new (&(group_map[name]->_matrix)) Map>(data, rows, cols); - for (int i = 0; i < rows * cols; ++i) { - data[i] = static_cast(values[static_cast(i)]); - } + trx::detail::remap(group_map[name]->_matrix, group_map[name]->mmap.data(), rows, cols); + for (int i = 0; i < rows * cols; ++i) { + group_map[name]->_matrix(i) = static_cast
(values[static_cast(i)]); } } @@ -2947,7 +2696,7 @@ void TrxFile
::add_dpg_from_matrix(const std::string &group, const std::string &dtype, const Eigen::MatrixBase &matrix) { if (matrix.size() == 0) { - throw std::invalid_argument("DPG matrix cannot be empty"); + throw TrxArgumentError("DPG matrix cannot be empty"); } std::vector values; values.reserve(static_cast(matrix.size())); @@ -3018,10 +2767,7 @@ template std::unique_ptr> TrxFile
::subset_streamlines(const std::vector &streamline_ids, bool build_cache_for_result) const { if (!this->streamlines) { - auto empty = std::make_unique>(); - empty->header = _json_set(this->header, "NB_VERTICES", 0); - empty->header = _json_set(empty->header, "NB_STREAMLINES", 0); - return empty; + return this->make_empty_like(); } std::vector offsets; @@ -3038,18 +2784,12 @@ std::unique_ptr> TrxFile
::subset_streamlines(const std::vector(this->streamlines->_lengths(static_cast(i))); } } else { - auto empty = std::make_unique>(); - empty->header = _json_set(this->header, "NB_VERTICES", 0); - empty->header = _json_set(empty->header, "NB_STREAMLINES", 0); - return empty; + return this->make_empty_like(); } const size_t nb_streamlines = offsets.size() > 0 ? offsets.size() - 1 : 0; if (nb_streamlines == 0) { - auto empty = std::make_unique>(); - empty->header = _json_set(this->header, "NB_VERTICES", 0); - empty->header = _json_set(empty->header, "NB_STREAMLINES", 0); - return empty; + return this->make_empty_like(); } std::vector selected; @@ -3057,7 +2797,7 @@ std::unique_ptr> TrxFile
::subset_streamlines(const std::vector seen(nb_streamlines, 0); for (uint32_t id : streamline_ids) { if (id >= nb_streamlines) { - throw std::invalid_argument("Streamline id out of range"); + throw TrxArgumentError("Streamline id out of range"); } if (!seen[id]) { selected.push_back(id); @@ -3066,10 +2806,7 @@ std::unique_ptr> TrxFile
::subset_streamlines(const std::vector>(); - empty->header = _json_set(this->header, "NB_VERTICES", 0); - empty->header = _json_set(empty->header, "NB_STREAMLINES", 0); - return empty; + return this->make_empty_like(); } std::vector old_to_new(nb_streamlines, -1); @@ -3163,13 +2900,7 @@ std::unique_ptr> TrxFile
::subset_streamlines(const std::vectordata_per_group.empty() && !out->groups.empty()) { std::string dpg_dir = out->_uncompressed_folder_handle + SEPARATOR + "dpg" + SEPARATOR; - { - std::error_code ec; - trx::fs::create_directories(dpg_dir, ec); - if (ec) { - throw std::runtime_error("Could not create directory " + dpg_dir); - } - } + mkdir_or_throw(dpg_dir); for (const auto &group_kv : out->groups) { const std::string &group_name = group_kv.first; @@ -3179,13 +2910,7 @@ std::unique_ptr> TrxFile
::subset_streamlines(const std::vectordata_per_group.find(group_name) == out->data_per_group.end()) { out->data_per_group.emplace(group_name, std::map>>{}); @@ -3206,19 +2931,9 @@ std::unique_ptr> TrxFile
::subset_streamlines(const std::vectordata_per_group[group_name][field_name]->mmap = _create_memmap(dpg_filename, dpg_shape, "w+", dpg_dtype); - if (dpg_dtype.compare("float16") == 0) { - new (&(out->data_per_group[group_name][field_name]->_matrix)) Map>( - reinterpret_cast(out->data_per_group[group_name][field_name]->mmap.data()), - std::get<0>(dpg_shape), std::get<1>(dpg_shape)); - } else if (dpg_dtype.compare("float32") == 0) { - new (&(out->data_per_group[group_name][field_name]->_matrix)) Map>( - reinterpret_cast(out->data_per_group[group_name][field_name]->mmap.data()), - std::get<0>(dpg_shape), std::get<1>(dpg_shape)); - } else { - new (&(out->data_per_group[group_name][field_name]->_matrix)) Map>( - reinterpret_cast(out->data_per_group[group_name][field_name]->mmap.data()), - std::get<0>(dpg_shape), std::get<1>(dpg_shape)); - } + trx::detail::remap(out->data_per_group[group_name][field_name]->_matrix, + out->data_per_group[group_name][field_name]->mmap.data(), + std::get<0>(dpg_shape), std::get<1>(dpg_shape)); for (int i = 0; i < out->data_per_group[group_name][field_name]->_matrix.rows(); ++i) { for (int j = 0; j < out->data_per_group[group_name][field_name]->_matrix.cols(); ++j) { diff --git a/src/detail/dtype_helpers.cpp b/src/detail/dtype_helpers.cpp index 197e5b7..8bf056c 100644 --- a/src/detail/dtype_helpers.cpp +++ b/src/detail/dtype_helpers.cpp @@ -61,22 +61,22 @@ std::string _get_dtype(const std::string &dtype) { } bool _is_dtype_valid(const std::string &ext) { - if (std::find(::trx::dtypes.begin(), ::trx::dtypes.end(), ext) != ::trx::dtypes.end()) + if (std::find(dtypes.begin(), dtypes.end(), ext) != dtypes.end()) return true; return false; } std::tuple _split_ext_with_dimensionality(const std::string &filename) { - std::string base = ::trx::path_basename(filename); + std::string base = path_basename(filename); const size_t num_splits = std::count(base.begin(), base.end(), '.'); int dim = 0; if (num_splits != 1 && num_splits != 2) { - throw 
std::invalid_argument("Invalid filename"); + throw TrxFormatError("Invalid filename"); } - const std::string ext = ::trx::get_ext(base); + const std::string ext = get_ext(base); base = base.substr(0, base.length() - ext.length() - 1); @@ -91,8 +91,7 @@ std::tuple _split_ext_with_dimensionality(const s const bool is_valid = _is_dtype_valid(ext); if (!is_valid) { - // TODO: make formatted string and include provided extension name - throw std::invalid_argument("Unsupported file extension"); + throw TrxDTypeError("Unsupported file extension: " + ext); } std::tuple output{base, dim, ext}; diff --git a/src/trx.cpp b/src/trx.cpp index 67f0a58..f875488 100644 --- a/src/trx.cpp +++ b/src/trx.cpp @@ -31,6 +31,7 @@ #include #include +#include // #define ZIP_DD_SIG 0x08074b50 // #define ZIP_CD_SIG 0x06054b50 @@ -59,13 +60,13 @@ std::string normalize_slashes(std::string path) { return path; } + bool parse_positions_dtype(const std::string &filename, std::string &out_dtype) { const std::string normalized = normalize_slashes(filename); try { - const auto tuple = trx::detail::_split_ext_with_dimensionality(normalized); - const std::string &base = std::get<0>(tuple); + auto [base, dim, dtype] = trx::detail::_split_ext_with_dimensionality(normalized); if (base == "positions") { - out_dtype = std::get<2>(tuple); + out_dtype = dtype; return true; } } catch (const std::exception &) { @@ -107,7 +108,7 @@ TrxSaveMode resolve_save_mode(const std::string &filename, TrxSaveMode requested std::array read_xyz_as_double(const TypedArray &positions, size_t row_index) { if (positions.cols != 3) { - throw std::runtime_error("Positions must have 3 columns."); + throw TrxFormatError("Positions must have 3 columns."); } if (row_index >= static_cast(positions.rows)) { throw std::out_of_range("Position row index out of range"); @@ -130,7 +131,7 @@ std::array read_xyz_as_double(const TypedArray &positions, size_t row view(static_cast(row_index), 1), view(static_cast(row_index), 2)}; } - throw 
std::runtime_error("Unsupported positions dtype for streamline extraction: " + positions.dtype); + throw TrxDTypeError("Unsupported positions dtype for streamline extraction: " + positions.dtype); } TypedArray make_typed_array(const std::string &filename, int rows, int cols, const std::string &dtype) { @@ -148,7 +149,7 @@ TypedArray make_typed_array(const std::string &filename, int rows, int cols, con std::string detect_positions_dtype(const std::string &path) { const trx::fs::path input(path); if (!trx::fs::exists(input)) { - throw std::runtime_error("Input path does not exist: " + path); + throw TrxIOError("Input path does not exist: " + path); } std::error_code ec; @@ -165,14 +166,14 @@ std::string detect_positions_dtype(const std::string &path) { } int err = 0; - zip_t *zf = open_zip_for_read(path, err); - if (zf == nullptr) { - throw std::runtime_error("Could not open zip file: " + path); + detail::ZipArchive zf(open_zip_for_read(path, err)); + if (!zf) { + throw TrxIOError("Could not open zip file: " + path); } std::string dtype; - const zip_int64_t count = zip_get_num_entries(zf, 0); + const zip_int64_t count = zip_get_num_entries(zf.get(), 0); for (zip_int64_t i = 0; i < count; ++i) { - const auto *name = zip_get_name(zf, i, 0); + const auto *name = zip_get_name(zf.get(), i, 0); if (name == nullptr) { continue; } @@ -180,7 +181,6 @@ std::string detect_positions_dtype(const std::string &path) { break; } } - zip_close(zf); return dtype; } @@ -201,7 +201,7 @@ TrxScalarType detect_positions_scalar_type(const std::string &path, TrxScalarTyp bool is_trx_directory(const std::string &path) { const trx::fs::path input(path); if (!trx::fs::exists(input)) { - throw std::runtime_error("Input path does not exist: " + path); + throw TrxIOError("Input path does not exist: " + path); } std::error_code ec; return trx::fs::is_directory(input, ec) && !ec; @@ -257,7 +257,7 @@ const TypedArray *AnyTrxFile::get_dpv(const std::string &name) const { std::vector> 
AnyTrxFile::get_streamline(size_t streamline_index) const { if (offsets_u64.empty()) { - throw std::runtime_error("TRX offsets are empty."); + throw TrxFormatError("TRX offsets are empty."); } const size_t n_streamlines = offsets_u64.size() - 1; if (streamline_index >= n_streamlines) { @@ -286,17 +286,7 @@ void AnyTrxFile::close() { _uncompressed_folder_handle.clear(); _owns_uncompressed_folder = false; - std::vector> affine(4, std::vector(4, 0.0f)); - for (int i = 0; i < 4; i++) { - affine[i][i] = 1.0f; - } - std::vector dimensions{1, 1, 1}; - json::object header_obj; - header_obj["VOXEL_TO_RASMM"] = affine; - header_obj["DIMENSIONS"] = dimensions; - header_obj["NB_VERTICES"] = 0; - header_obj["NB_STREAMLINES"] = 0; - header = json(header_obj); + header = default_header(); } void AnyTrxFile::_cleanup_temporary_directory() { @@ -311,7 +301,7 @@ void AnyTrxFile::_cleanup_temporary_directory() { AnyTrxFile AnyTrxFile::load(const std::string &path) { trx::fs::path input(path); if (!trx::fs::exists(input)) { - throw std::runtime_error("Input path does not exist: " + path); + throw TrxIOError("Input path does not exist: " + path); } std::error_code ec; if (trx::fs::is_directory(input, ec) && !ec) { @@ -322,13 +312,12 @@ AnyTrxFile AnyTrxFile::load(const std::string &path) { AnyTrxFile AnyTrxFile::load_from_zip(const std::string &filename) { int errorp = 0; - zip_t *zf = open_zip_for_read(filename, errorp); - if (zf == nullptr) { - throw std::runtime_error("Could not open zip file: " + filename); + detail::ZipArchive zf(open_zip_for_read(filename, errorp)); + if (!zf) { + throw TrxIOError("Could not open zip file: " + filename); } - std::string temp_dir = extract_zip_to_directory(zf); - zip_close(zf); + std::string temp_dir = extract_zip_to_directory(zf.get()); auto trx = AnyTrxFile::load_from_directory(temp_dir); trx._uncompressed_folder_handle = temp_dir; @@ -382,14 +371,14 @@ AnyTrxFile AnyTrxFile::load_from_directory(const std::string &path) { detail += "]"; } } - 
throw std::runtime_error(detail); + throw TrxIOError(detail); } std::string jstream((std::istreambuf_iterator(header_file)), std::istreambuf_iterator()); header_file.close(); std::string err; json header = json::parse(jstream, err); if (!err.empty()) { - throw std::runtime_error("Failed to parse header.json: " + err); + throw TrxIOError("Failed to parse header.json: " + err); } std::map> files_pointer_size; @@ -408,7 +397,7 @@ AnyTrxFile::_create_from_pointer(json header, trx.header = header; if (!header["NB_VERTICES"].is_number() || !header["NB_STREAMLINES"].is_number()) { - throw std::invalid_argument("Missing NB_VERTICES or NB_STREAMLINES in header.json"); + throw TrxFormatError("Missing NB_VERTICES or NB_STREAMLINES in header.json"); } const int nb_vertices = header["NB_VERTICES"].int_value(); @@ -417,28 +406,9 @@ AnyTrxFile::_create_from_pointer(json header, for (auto x = dict_pointer_size.rbegin(); x != dict_pointer_size.rend(); ++x) { const std::string elem_filename = x->first; - trx::fs::path elem_path(elem_filename); - trx::fs::path folder_path = elem_path.parent_path(); - std::string folder; - if (!root.empty()) { - trx::fs::path rel_path = elem_path.lexically_relative(trx::fs::path(root)); - std::string rel_str = rel_path.string(); - if (!rel_str.empty() && rel_str.rfind("..", 0) != 0) { - folder = rel_path.parent_path().string(); - } else { - folder = folder_path.string(); - } - } else { - folder = folder_path.string(); - } - if (folder == ".") { - folder.clear(); - } + std::string folder = folder_from_path(elem_filename, root); - std::tuple base_tuple = trx::detail::_split_ext_with_dimensionality(elem_filename); - std::string base(std::get<0>(base_tuple)); - int dim = std::get<1>(base_tuple); - std::string ext(std::get<2>(base_tuple)); + auto [base, dim, ext] = trx::detail::_split_ext_with_dimensionality(elem_filename); ext = _normalize_dtype(ext); @@ -446,35 +416,35 @@ AnyTrxFile::_create_from_pointer(json header, if (base == "positions" && 
(folder.empty() || folder == ".")) { if (size != static_cast(nb_vertices) * 3 || dim != 3) { - throw std::invalid_argument("Wrong positions size/dimensionality"); + throw TrxFormatError("Wrong positions size/dimensionality"); } if (ext != "float16" && ext != "float32" && ext != "float64") { - throw std::invalid_argument("Unsupported positions dtype: " + ext); + throw TrxDTypeError("Unsupported positions dtype: " + ext); } trx.positions = make_typed_array(elem_filename, nb_vertices, 3, ext); } else if (base == "offsets" && (folder.empty() || folder == ".")) { if (size != static_cast(nb_streamlines) + 1 || dim != 1) { - throw std::invalid_argument("Wrong offsets size/dimensionality"); + throw TrxFormatError("Wrong offsets size/dimensionality"); } if (ext != "uint32" && ext != "uint64") { - throw std::invalid_argument("Unsupported offsets dtype: " + ext); + throw TrxDTypeError("Unsupported offsets dtype: " + ext); } trx.offsets = make_typed_array(elem_filename, nb_streamlines + 1, 1, ext); } else if (folder == "dps") { const int nb_scalar = nb_streamlines > 0 ? static_cast(size / nb_streamlines) : 0; if (nb_streamlines == 0 || size % nb_streamlines != 0 || nb_scalar != dim) { - throw std::invalid_argument("Wrong dps size/dimensionality"); + throw TrxFormatError("Wrong dps size/dimensionality"); } trx.data_per_streamline.emplace(base, make_typed_array(elem_filename, nb_streamlines, nb_scalar, ext)); } else if (folder == "dpv") { const int nb_scalar = nb_vertices > 0 ? 
static_cast(size / nb_vertices) : 0; if (nb_vertices == 0 || size % nb_vertices != 0 || nb_scalar != dim) { - throw std::invalid_argument("Wrong dpv size/dimensionality"); + throw TrxFormatError("Wrong dpv size/dimensionality"); } trx.data_per_vertex.emplace(base, make_typed_array(elem_filename, nb_vertices, nb_scalar, ext)); } else if (folder.rfind("dpg", 0) == 0) { if (size != dim) { - throw std::invalid_argument("Wrong dpg size/dimensionality"); + throw TrxFormatError("Wrong dpg size/dimensionality"); } std::string data_name = path_basename(base); std::string sub_folder = path_basename(folder); @@ -482,19 +452,19 @@ AnyTrxFile::_create_from_pointer(json header, make_typed_array(elem_filename, 1, static_cast(size), ext)); } else if (folder == "groups") { if (dim != 1) { - throw std::invalid_argument("Wrong group dimensionality"); + throw TrxFormatError("Wrong group dimensionality"); } if (ext != "uint32") { - throw std::invalid_argument("Unsupported group dtype: " + ext); + throw TrxDTypeError("Unsupported group dtype: " + ext); } trx.groups.emplace(base, make_typed_array(elem_filename, static_cast(size), 1, ext)); } else { - throw std::invalid_argument("Entry is not part of a valid TRX structure: " + elem_filename); + throw TrxFormatError("Entry is not part of a valid TRX structure: " + elem_filename); } } if (trx.positions.empty() || trx.offsets.empty()) { - throw std::invalid_argument("Missing essential data."); + throw TrxFormatError("Missing essential data."); } const size_t offsets_count = trx.offsets.size(); @@ -512,7 +482,7 @@ AnyTrxFile::_create_from_pointer(json header, trx.offsets_u64[i] = static_cast(src[i]); } } else { - throw std::invalid_argument("Unsupported offsets datatype: " + trx.offsets.dtype); + throw TrxDTypeError("Unsupported offsets datatype: " + trx.offsets.dtype); } } @@ -521,7 +491,7 @@ AnyTrxFile::_create_from_pointer(json header, for (size_t i = 0; i + 1 < offsets_count; ++i) { const uint64_t diff = trx.offsets_u64[i + 1] - 
trx.offsets_u64[i]; if (diff > std::numeric_limits::max()) { - throw std::runtime_error("Offset difference exceeds uint32 range"); + throw TrxFormatError("Offset difference exceeds uint32 range"); } trx.lengths[i] = static_cast(diff); } @@ -540,94 +510,87 @@ void AnyTrxFile::save(const std::string &filename, const TrxSaveOptions &options const std::string ext = get_ext(filename); const TrxSaveMode save_mode = resolve_save_mode(filename, options.mode); if (ext.size() > 0 && ext != "zip" && ext != "trx") { - throw std::invalid_argument("Unsupported extension: " + ext); + throw TrxDTypeError("Unsupported extension: " + ext); } if (offsets.empty()) { - throw std::runtime_error("Cannot save TRX without offsets data"); + throw TrxFormatError("Cannot save TRX without offsets data"); } if (offsets_u64.empty()) { - throw std::runtime_error("Cannot save TRX without decoded offsets"); + throw TrxFormatError("Cannot save TRX without decoded offsets"); } if (header["NB_STREAMLINES"].is_number()) { const auto nb_streamlines = static_cast(header["NB_STREAMLINES"].int_value()); if (offsets_u64.size() != nb_streamlines + 1) { - throw std::runtime_error("TRX offsets size does not match NB_STREAMLINES"); + throw TrxFormatError("TRX offsets size does not match NB_STREAMLINES"); } } if (header["NB_VERTICES"].is_number()) { const auto nb_vertices = static_cast(header["NB_VERTICES"].int_value()); const auto last = offsets_u64.back(); if (last != nb_vertices) { - throw std::runtime_error("TRX offsets sentinel does not match NB_VERTICES"); + throw TrxFormatError("TRX offsets sentinel does not match NB_VERTICES"); } } for (size_t i = 1; i < offsets_u64.size(); ++i) { if (offsets_u64[i] < offsets_u64[i - 1]) { - throw std::runtime_error("TRX offsets must be monotonically increasing"); + throw TrxFormatError("TRX offsets must be monotonically increasing"); } } if (!positions.empty()) { const auto last = offsets_u64.back(); if (last != static_cast(positions.rows)) { - throw 
std::runtime_error("TRX positions row count does not match offsets sentinel"); + throw TrxFormatError("TRX positions row count does not match offsets sentinel"); } } const std::string source_dir = !_uncompressed_folder_handle.empty() ? _uncompressed_folder_handle : _backing_directory; if (source_dir.empty()) { - throw std::runtime_error("TRX file has no backing directory to save from"); + throw TrxIOError("TRX file has no backing directory to save from"); } if (save_mode == TrxSaveMode::Archive) { int errorp; - zip_t *zf; - if ((zf = zip_open(filename.c_str(), ZIP_CREATE + ZIP_TRUNCATE, &errorp)) == nullptr) { - throw std::runtime_error("Could not open archive " + filename + ": " + strerror(errorp)); + detail::ZipArchive zf(zip_open(filename.c_str(), ZIP_CREATE + ZIP_TRUNCATE, &errorp)); + if (!zf) { + throw TrxIOError("Could not open archive " + filename + ": " + strerror(errorp)); } const std::string header_payload = header.dump() + "\n"; zip_source_t *header_source = - zip_source_buffer(zf, header_payload.data(), header_payload.size(), 0 /* do not free */); + zip_source_buffer(zf.get(), header_payload.data(), header_payload.size(), 0 /* do not free */); if (header_source == nullptr) { - zip_close(zf); - throw std::runtime_error("Failed to create zip source for header.json: " + std::string(zip_strerror(zf))); + throw TrxIOError("Failed to create zip source for header.json: " + + std::string(zip_strerror(zf.get()))); } - const zip_int64_t header_idx = zip_file_add(zf, "header.json", header_source, ZIP_FL_ENC_UTF_8 | ZIP_FL_OVERWRITE); + const zip_int64_t header_idx = + zip_file_add(zf.get(), "header.json", header_source, ZIP_FL_ENC_UTF_8 | ZIP_FL_OVERWRITE); if (header_idx < 0) { - zip_source_free(header_source); - zip_close(zf); - throw std::runtime_error("Failed to add header.json to archive: " + std::string(zip_strerror(zf))); + throw TrxIOError("Failed to add header.json to archive: " + std::string(zip_strerror(zf.get()))); } const zip_int32_t compression = 
static_cast(options.compression_standard); - if (zip_set_file_compression(zf, header_idx, compression, 0) < 0) { - zip_close(zf); - throw std::runtime_error("Failed to set compression for header.json: " + std::string(zip_strerror(zf))); + if (zip_set_file_compression(zf.get(), header_idx, compression, 0) < 0) { + throw TrxIOError("Failed to set compression for header.json: " + + std::string(zip_strerror(zf.get()))); } const std::unordered_set skip = {"header.json"}; - zip_from_folder(zf, source_dir, source_dir, options.compression_standard, &skip); - if (zip_close(zf) != 0) { - throw std::runtime_error("Unable to close archive " + filename + ": " + zip_strerror(zf)); - } + zip_from_folder(zf.get(), source_dir, source_dir, options.compression_standard, &skip); + zf.commit(filename); } else { std::error_code ec; if (trx::fs::exists(filename, ec) && trx::fs::is_directory(filename, ec)) { if (!options.overwrite_existing) { - throw std::runtime_error("Output directory already exists: " + filename); + throw TrxIOError("Output directory already exists: " + filename); } if (rm_dir(filename) != 0) { - throw std::runtime_error("Could not remove existing directory " + filename); + throw TrxIOError("Could not remove existing directory " + filename); } } trx::fs::path dest_path(filename); if (dest_path.has_parent_path()) { - std::error_code parent_ec; - trx::fs::create_directories(dest_path.parent_path(), parent_ec); - if (parent_ec) { - throw std::runtime_error("Could not create output parent directory: " + dest_path.parent_path().string()); - } + mkdir_or_throw(dest_path.parent_path().string()); } std::error_code source_ec; const trx::fs::path source_path = trx::fs::weakly_canonical(trx::fs::path(source_dir), source_ec); @@ -642,17 +605,17 @@ void AnyTrxFile::save(const std::string &filename, const TrxSaveOptions &options const trx::fs::path final_header_path = dest_path / "header.json"; std::ofstream out_json(final_header_path, std::ios::out | std::ios::trunc); if 
(!out_json.is_open()) { - throw std::runtime_error("Failed to write header.json to: " + final_header_path.string()); + throw TrxIOError("Failed to write header.json to: " + final_header_path.string()); } out_json << header.dump() << std::endl; out_json.close(); ec.clear(); if (!trx::fs::exists(filename, ec) || !trx::fs::is_directory(filename, ec)) { - throw std::runtime_error("Failed to create output directory: " + filename); + throw TrxIOError("Failed to create output directory: " + filename); } if (!trx::fs::exists(final_header_path)) { - throw std::runtime_error("Missing header.json in output directory: " + final_header_path.string()); + throw TrxFormatError("Missing header.json in output directory: " + final_header_path.string()); } } } @@ -666,7 +629,7 @@ void populate_fps(const string &name, std::mappath(); const std::string filename = entry_path.filename().string(); @@ -696,14 +659,14 @@ void populate_fps(const string &name, std::map(dtype_size) == 0) { @@ -712,7 +675,7 @@ void populate_fps(const string &name, std::map 0) { + while ((nbytes = zip_fread(zh.get(), buffer.data(), buff_len - 1)) > 0) { jstream.append(buffer.data(), static_cast(nbytes)); } - zip_fclose(zh); - // convert jstream data into Json. std::string err; auto root = json::parse(jstream, err); if (!err.empty()) { - throw std::runtime_error("Failed to parse header.json: " + err); + throw TrxIOError("Failed to parse header.json: " + err); } return root; } @@ -808,7 +769,7 @@ mio::shared_mmap_sink _create_memmap(std::string filename, return rw_mmap; } -// TODO: support FORTRAN ORDERING +// Known limitation: only C (row-major) ordering is supported; Fortran ordering is not. // template json assignHeader(const json &root) { @@ -842,16 +803,14 @@ void get_reference_info( const Eigen::RowVectorXi &dimensions) { // NOLINT(misc-use-internal-linkage,misc-include-cleaner) static_cast(affine); static_cast(dimensions); - // TODO: find a library to use for nifti and trk (MRtrix??) 
+ // Known limitation: NIfTI support is partially addressed by nifti_io.cpp; TRK is not yet supported. // if (reference.find(".nii") != std::string::npos) // { // } if (reference.find(".trk") != std::string::npos) { - // TODO: Create exception class - throw std::runtime_error("Trk reference not implemented"); + throw TrxError("Trk reference not implemented"); } - // TODO: Create exception class - throw std::runtime_error("Trk reference not implemented"); + throw TrxError("Trk reference not implemented"); } void copy_dir(const string &src, const string &dst) { @@ -862,16 +821,14 @@ void copy_dir(const string &src, const string &dst) { return; } - if (!trx::fs::create_directories(dst_path, ec) && ec) { - throw std::runtime_error(std::string("Could not create directory ") + dst); - } + mkdir_or_throw(dst); ec.clear(); const auto options = trx::fs::copy_options::recursive | trx::fs::copy_options::overwrite_existing | trx::fs::copy_options::skip_symlinks; trx::fs::copy(src_path, dst_path, options, ec); if (ec) { - throw std::runtime_error("Failed to copy directory: " + ec.message()); + throw TrxIOError("Failed to copy directory: " + ec.message()); } } @@ -879,7 +836,7 @@ void copy_file(const string &src, const string &dst) { std::error_code ec; trx::fs::copy_file(src, dst, trx::fs::copy_options::overwrite_existing, ec); if (ec) { - throw std::runtime_error(std::string("Failed to copy file ") + src + ": " + ec.message()); + throw TrxIOError(std::string("Failed to copy file ") + src + ": " + ec.message()); } } int rm_dir(const string &d) { @@ -936,11 +893,7 @@ std::string make_temp_dir(const std::string &prefix) { const trx::fs::path base_path(base_dir); std::error_code ec; if (!trx::fs::exists(base_path, ec)) { - ec.clear(); - trx::fs::create_directories(base_path, ec); - if (ec) { - throw std::runtime_error("Failed to create base temp directory: " + base_dir); - } + mkdir_or_throw(base_path.string()); } static std::mt19937_64 rng(std::random_device{}()); @@ -959,15 
+912,15 @@ std::string make_temp_dir(const std::string &prefix) { return candidate.string(); } if (ec && ec != std::errc::file_exists) { - throw std::runtime_error("Failed to create temporary directory: " + ec.message()); + throw TrxIOError("Failed to create temporary directory: " + ec.message()); } } - throw std::runtime_error("Failed to create temporary directory"); + throw TrxIOError("Failed to create temporary directory"); } std::string extract_zip_to_directory(zip_t *zfolder) { if (zfolder == nullptr) { - throw std::invalid_argument("Zip archive pointer is null"); + throw TrxArgumentError("Zip archive pointer is null"); } const std::string root_dir = make_temp_dir("trx_zip"); const trx::fs::path normalized_root = trx::fs::path(root_dir).lexically_normal(); @@ -982,7 +935,7 @@ std::string extract_zip_to_directory(zip_t *zfolder) { const trx::fs::path entry_path(entry); if (entry_path.is_absolute()) { - throw std::runtime_error("Zip entry has absolute path: " + entry); + throw TrxIOError("Zip entry has absolute path: " + entry); } const trx::fs::path normalized_entry = entry_path.lexically_normal(); @@ -990,60 +943,42 @@ std::string extract_zip_to_directory(zip_t *zfolder) { const trx::fs::path normalized_out = out_path.lexically_normal(); if (!is_path_within(normalized_out, normalized_root)) { - throw std::runtime_error("Zip entry escapes temporary directory: " + entry); + throw TrxIOError("Zip entry escapes temporary directory: " + entry); } if (!entry.empty() && entry.back() == '/') { - std::error_code ec; - trx::fs::create_directories(normalized_out, ec); - if (ec) { - throw std::runtime_error("Failed to create directory: " + normalized_out.string()); - } + mkdir_or_throw(normalized_out.string()); continue; } - std::error_code ec; - trx::fs::create_directories(normalized_out.parent_path(), ec); - if (ec) { - throw std::runtime_error("Failed to create parent directory: " + normalized_out.parent_path().string()); - } + 
mkdir_or_throw(normalized_out.parent_path().string()); - zip_file_t *zf = zip_fopen_index(zfolder, i, ZIP_FL_UNCHANGED); - if (zf == nullptr) { - throw std::runtime_error("Failed to open zip entry: " + entry); + detail::ZipFile zf(zip_fopen_index(zfolder, i, ZIP_FL_UNCHANGED)); + if (!zf) { + throw TrxIOError("Failed to open zip entry: " + entry); } std::ofstream out(normalized_out.string(), std::ios::binary); if (!out.is_open()) { - zip_fclose(zf); - throw std::runtime_error("Failed to open output file: " + normalized_out.string()); + throw TrxIOError("Failed to open output file: " + normalized_out.string()); } std::array buffer{}; zip_int64_t nbytes = 0; - while ((nbytes = zip_fread(zf, buffer.data(), buffer.size())) > 0) { + while ((nbytes = zip_fread(zf.get(), buffer.data(), buffer.size())) > 0) { out.write(buffer.data(), nbytes); if (!out) { - out.close(); - zip_fclose(zf); - throw std::runtime_error("Failed to write to output file: " + normalized_out.string()); + throw TrxIOError("Failed to write to output file: " + normalized_out.string()); } } if (nbytes < 0) { - out.close(); - zip_fclose(zf); - throw std::runtime_error("Failed to read data from zip entry: " + entry); + throw TrxIOError("Failed to read data from zip entry: " + entry); } out.flush(); if (!out) { - out.close(); - zip_fclose(zf); - throw std::runtime_error("Failed to flush output file: " + normalized_out.string()); + throw TrxIOError("Failed to flush output file: " + normalized_out.string()); } - - out.close(); - zip_fclose(zf); } return root_dir; @@ -1057,7 +992,7 @@ void zip_from_folder(zip_t *zf, std::error_code ec; for (trx::fs::recursive_directory_iterator it(directory, ec), end; it != end; it.increment(ec)) { if (ec) { - throw std::runtime_error("Failed to read directory: " + directory); + throw TrxIOError("Failed to read directory: " + directory); } const trx::fs::path current = it->path(); const std::string zip_fname = rm_root(root, current.string()); @@ -1073,7 +1008,7 @@ void 
zip_from_folder(zip_t *zf, const std::string fullpath = current.string(); zip_source_t *source = zip_source_file(zf, fullpath.c_str(), 0, 0); if (source == nullptr) { - throw std::runtime_error(std::string("Error adding file ") + zip_fname + ": " + zip_strerror(zf)); + throw TrxIOError(std::string("Error adding file ") + zip_fname + ": " + zip_strerror(zf)); } if (skip && skip->find(zip_fname) != skip->end()) { zip_source_free(source); @@ -1082,11 +1017,11 @@ void zip_from_folder(zip_t *zf, const zip_int64_t file_idx = zip_file_add(zf, zip_fname.c_str(), source, ZIP_FL_ENC_UTF_8); if (file_idx < 0) { zip_source_free(source); - throw std::runtime_error(std::string("Error adding file ") + zip_fname + ": " + zip_strerror(zf)); + throw TrxIOError(std::string("Error adding file ") + zip_fname + ": " + zip_strerror(zf)); } const zip_int32_t compression = static_cast(compression_standard); if (zip_set_file_compression(zf, file_idx, compression, 0) < 0) { - throw std::runtime_error(std::string("Error setting compression for ") + zip_fname + ": " + zip_strerror(zf)); + throw TrxIOError(std::string("Error setting compression for ") + zip_fname + ": " + zip_strerror(zf)); } } } @@ -1126,7 +1061,7 @@ void write_typed_array_file(const std::string &path, const TypedArray &arr) { const auto bytes = arr.to_bytes(); std::ofstream out(path, std::ios::binary | std::ios::out | std::ios::trunc); if (!out.is_open()) { - throw std::runtime_error("Failed to open output file: " + path); + throw TrxIOError("Failed to open output file: " + path); } if (bytes.data && bytes.size > 0) { out.write(reinterpret_cast(bytes.data), static_cast(bytes.size)); @@ -1138,10 +1073,10 @@ void write_typed_array_file(const std::string &path, const TypedArray &arr) { void AnyTrxFile::for_each_positions_chunk(size_t chunk_bytes, const PositionsChunkCallback &fn) const { if (positions.empty()) { - throw std::runtime_error("TRX positions are empty."); + throw TrxFormatError("TRX positions are empty."); } if 
(positions.cols != 3) { - throw std::runtime_error("Positions must have 3 columns."); + throw TrxFormatError("Positions must have 3 columns."); } if (!fn) { return; @@ -1167,10 +1102,10 @@ void AnyTrxFile::for_each_positions_chunk(size_t chunk_bytes, const PositionsChu void AnyTrxFile::for_each_positions_chunk_mutable(size_t chunk_bytes, const PositionsChunkMutableCallback &fn) { if (positions.empty()) { - throw std::runtime_error("TRX positions are empty."); + throw TrxFormatError("TRX positions are empty."); } if (positions.cols != 3) { - throw std::runtime_error("Positions must have 3 columns."); + throw TrxFormatError("Positions must have 3 columns."); } if (!fn) { return; @@ -1198,10 +1133,10 @@ PositionsOutputInfo prepare_positions_output(const AnyTrxFile &input, const std::string &output_directory, const PrepareOutputOptions &options) { if (input.positions.empty() || input.offsets.empty()) { - throw std::runtime_error("Input TRX missing positions/offsets."); + throw TrxFormatError("Input TRX missing positions/offsets."); } if (input.positions.cols != 3) { - throw std::runtime_error("Positions must have 3 columns."); + throw TrxFormatError("Positions must have 3 columns."); } std::error_code ec; @@ -1209,20 +1144,16 @@ PositionsOutputInfo prepare_positions_output(const AnyTrxFile &input, if (options.overwrite_existing) { trx::fs::remove_all(output_directory, ec); } else { - throw std::runtime_error("Output directory already exists: " + output_directory); + throw TrxIOError("Output directory already exists: " + output_directory); } } - ec.clear(); - trx::fs::create_directories(output_directory, ec); - if (ec) { - throw std::runtime_error("Failed to create output directory: " + output_directory); - } + mkdir_or_throw(output_directory); const std::string header_path = output_directory + SEPARATOR + "header.json"; { std::ofstream out(header_path, std::ios::out | std::ios::trunc); if (!out.is_open()) { - throw std::runtime_error("Failed to write header.json to: " 
+ header_path); + throw TrxIOError("Failed to write header.json to: " + header_path); } out << input.header.dump() << std::endl; } @@ -1275,21 +1206,21 @@ PositionsOutputInfo prepare_positions_output(const AnyTrxFile &input, void merge_trx_shards(const MergeTrxShardsOptions &options) { if (options.shard_directories.empty()) { - throw std::invalid_argument("merge_trx_shards requires at least one shard directory"); + throw TrxArgumentError("merge_trx_shards requires at least one shard directory"); } auto read_header = [](const std::string &dir) { const std::string path = dir + SEPARATOR + "header.json"; std::ifstream in(path); if (!in.is_open()) { - throw std::runtime_error("Failed to open shard header: " + path); + throw TrxIOError("Failed to open shard header: " + path); } std::stringstream ss; ss << in.rdbuf(); std::string err; json parsed = json::parse(ss.str(), err); if (!err.empty()) { - throw std::runtime_error("Failed to parse shard header " + path + ": " + err); + throw TrxIOError("Failed to parse shard header " + path + ": " + err); } return parsed; }; @@ -1314,11 +1245,11 @@ void merge_trx_shards(const MergeTrxShardsOptions &options) { auto append_binary = [](const std::string &dst, const std::string &src) { std::ifstream in(src, std::ios::binary); if (!in.is_open()) { - throw std::runtime_error("Failed to open source for append: " + src); + throw TrxIOError("Failed to open source for append: " + src); } std::ofstream out(dst, std::ios::binary | std::ios::app); if (!out.is_open()) { - throw std::runtime_error("Failed to open destination for append: " + dst); + throw TrxIOError("Failed to open destination for append: " + dst); } std::vector buffer(1 << 20); while (in) { @@ -1333,11 +1264,11 @@ void merge_trx_shards(const MergeTrxShardsOptions &options) { auto append_offsets_with_base = [](const std::string &dst, const std::string &src, uint64_t base_vertices, bool skip_first) { std::ifstream in(src, std::ios::binary); if (!in.is_open()) { - throw 
std::runtime_error("Failed to open source offsets: " + src); + throw TrxIOError("Failed to open source offsets: " + src); } std::ofstream out(dst, std::ios::binary | std::ios::app); if (!out.is_open()) { - throw std::runtime_error("Failed to open destination offsets: " + dst); + throw TrxIOError("Failed to open destination offsets: " + dst); } constexpr size_t kChunkElems = (8 * 1024 * 1024) / sizeof(uint64_t); std::vector buffer(kChunkElems); @@ -1350,7 +1281,7 @@ void merge_trx_shards(const MergeTrxShardsOptions &options) { break; } if (bytes % static_cast(sizeof(uint64_t)) != 0) { - throw std::runtime_error("Offsets file has invalid byte count: " + src); + throw TrxFormatError("Offsets file has invalid byte count: " + src); } const size_t count = static_cast(bytes) / sizeof(uint64_t); size_t start_index = 0; @@ -1374,11 +1305,11 @@ void merge_trx_shards(const MergeTrxShardsOptions &options) { auto append_group_indices_with_base = [](const std::string &dst, const std::string &src, uint32_t base_streamlines) { std::ifstream in(src, std::ios::binary); if (!in.is_open()) { - throw std::runtime_error("Failed to open source group file: " + src); + throw TrxIOError("Failed to open source group file: " + src); } std::ofstream out(dst, std::ios::binary | std::ios::app); if (!out.is_open()) { - throw std::runtime_error("Failed to open destination group file: " + dst); + throw TrxIOError("Failed to open destination group file: " + dst); } constexpr size_t kChunkElems = (8 * 1024 * 1024) / sizeof(uint32_t); std::vector buffer(kChunkElems); @@ -1390,7 +1321,7 @@ void merge_trx_shards(const MergeTrxShardsOptions &options) { break; } if (bytes % static_cast(sizeof(uint32_t)) != 0) { - throw std::runtime_error("Group file has invalid byte count: " + src); + throw TrxFormatError("Group file has invalid byte count: " + src); } const size_t count = static_cast(bytes) / sizeof(uint32_t); for (size_t i = 0; i < count; ++i) { @@ -1408,11 +1339,11 @@ void merge_trx_shards(const 
MergeTrxShardsOptions &options) { return files; } if (!trx::fs::is_directory(path, ec)) { - throw std::runtime_error("Expected directory for subdir: " + path.string()); + throw TrxFormatError("Expected directory for subdir: " + path.string()); } for (trx::fs::directory_iterator it(path, ec), end; it != end; it.increment(ec)) { if (ec) { - throw std::runtime_error("Failed to read directory: " + path.string()); + throw TrxIOError("Failed to read directory: " + path.string()); } if (!it->is_regular_file()) { continue; @@ -1426,32 +1357,29 @@ void merge_trx_shards(const MergeTrxShardsOptions &options) { auto ensure_schema_match = [&](const std::string &subdir, const std::vector &schema_files, const std::string &shard) { const auto shard_files = list_subdir_files(shard, subdir); if (shard_files != schema_files) { - throw std::runtime_error("Shard schema mismatch for subdir '" + subdir + "': " + shard); + throw TrxFormatError("Shard schema mismatch for subdir '" + subdir + "': " + shard); } }; std::error_code ec; for (const auto &dir : options.shard_directories) { if (!trx::fs::exists(dir, ec) || !trx::fs::is_directory(dir, ec)) { - throw std::runtime_error("Shard directory does not exist: " + dir); + throw TrxFormatError("Shard directory does not exist: " + dir); } } const std::string output_dir = options.output_directory ? 
options.output_path : make_temp_dir("trx_merge"); if (trx::fs::exists(output_dir, ec)) { if (!options.overwrite_existing) { - throw std::runtime_error("Output already exists: " + output_dir); + throw TrxIOError("Output already exists: " + output_dir); } trx::fs::remove_all(output_dir, ec); } - trx::fs::create_directories(output_dir, ec); - if (ec) { - throw std::runtime_error("Failed to create output directory: " + output_dir); - } + mkdir_or_throw(output_dir); for (const auto &dir : options.shard_directories) { if (trx::fs::exists(dir + SEPARATOR + "dpg", ec)) { - throw std::runtime_error("merge_trx_shards currently does not support dpg/ merges"); + throw TrxArgumentError("merge_trx_shards currently does not support dpg/ merges"); } } @@ -1459,10 +1387,10 @@ void merge_trx_shards(const MergeTrxShardsOptions &options) { const std::string first_positions = find_file_with_prefix(options.shard_directories.front(), "positions."); const std::string first_offsets = find_file_with_prefix(options.shard_directories.front(), "offsets."); if (first_positions.empty() || first_offsets.empty()) { - throw std::runtime_error("Shard missing positions/offsets: " + options.shard_directories.front()); + throw TrxFormatError("Shard missing positions/offsets: " + options.shard_directories.front()); } if (get_ext(first_offsets) != "uint64") { - throw std::runtime_error("merge_trx_shards currently requires offsets.uint64"); + throw TrxArgumentError("merge_trx_shards currently requires offsets.uint64"); } const std::string positions_filename = trx::fs::path(first_positions).filename().string(); @@ -1472,13 +1400,13 @@ void merge_trx_shards(const MergeTrxShardsOptions &options) { { std::ofstream clear_positions(positions_out, std::ios::binary | std::ios::out | std::ios::trunc); if (!clear_positions.is_open()) { - throw std::runtime_error("Failed to create output positions file: " + positions_out); + throw TrxIOError("Failed to create output positions file: " + positions_out); } } { 
std::ofstream clear_offsets(offsets_out, std::ios::binary | std::ios::out | std::ios::trunc); if (!clear_offsets.is_open()) { - throw std::runtime_error("Failed to create output offsets file: " + offsets_out); + throw TrxIOError("Failed to create output offsets file: " + offsets_out); } } @@ -1497,20 +1425,20 @@ void merge_trx_shards(const MergeTrxShardsOptions &options) { for (const auto &name : dps_schema) { std::ofstream clear_file(output_dir + SEPARATOR + "dps" + SEPARATOR + name, std::ios::binary | std::ios::out | std::ios::trunc); if (!clear_file.is_open()) { - throw std::runtime_error("Failed to create merged dps file: " + name); + throw TrxIOError("Failed to create merged dps file: " + name); } } for (const auto &name : dpv_schema) { std::ofstream clear_file(output_dir + SEPARATOR + "dpv" + SEPARATOR + name, std::ios::binary | std::ios::out | std::ios::trunc); if (!clear_file.is_open()) { - throw std::runtime_error("Failed to create merged dpv file: " + name); + throw TrxIOError("Failed to create merged dpv file: " + name); } } for (const auto &name : groups_schema) { std::ofstream clear_file(output_dir + SEPARATOR + "groups" + SEPARATOR + name, std::ios::binary | std::ios::out | std::ios::trunc); if (!clear_file.is_open()) { - throw std::runtime_error("Failed to create merged group file: " + name); + throw TrxIOError("Failed to create merged group file: " + name); } } @@ -1529,13 +1457,13 @@ void merge_trx_shards(const MergeTrxShardsOptions &options) { const std::string shard_positions = find_file_with_prefix(shard_dir, "positions."); const std::string shard_offsets = find_file_with_prefix(shard_dir, "offsets."); if (shard_positions.empty() || shard_offsets.empty()) { - throw std::runtime_error("Shard missing positions/offsets: " + shard_dir); + throw TrxFormatError("Shard missing positions/offsets: " + shard_dir); } if (trx::fs::path(shard_positions).filename().string() != positions_filename) { - throw std::runtime_error("Shard positions dtype mismatch: " 
+ shard_dir); + throw TrxFormatError("Shard positions dtype mismatch: " + shard_dir); } if (trx::fs::path(shard_offsets).filename().string() != offsets_filename) { - throw std::runtime_error("Shard offsets dtype mismatch: " + shard_dir); + throw TrxFormatError("Shard offsets dtype mismatch: " + shard_dir); } append_binary(positions_out, shard_positions); @@ -1549,7 +1477,7 @@ void merge_trx_shards(const MergeTrxShardsOptions &options) { } for (const auto &name : groups_schema) { if (total_streamlines > static_cast(std::numeric_limits::max())) { - throw std::runtime_error("Group index offset exceeds uint32 range during merge"); + throw TrxFormatError("Group index offset exceeds uint32 range during merge"); } append_group_indices_with_base( output_dir + SEPARATOR + "groups" + SEPARATOR + name, @@ -1567,7 +1495,7 @@ void merge_trx_shards(const MergeTrxShardsOptions &options) { const std::string merged_header_path = output_dir + SEPARATOR + "header.json"; std::ofstream out(merged_header_path, std::ios::out | std::ios::trunc); if (!out.is_open()) { - throw std::runtime_error("Failed to write merged header: " + merged_header_path); + throw TrxIOError("Failed to write merged header: " + merged_header_path); } out << merged_header.dump() << std::endl; } @@ -1578,22 +1506,16 @@ void merge_trx_shards(const MergeTrxShardsOptions &options) { const trx::fs::path archive_path(options.output_path); if (archive_path.has_parent_path()) { - std::error_code parent_ec; - trx::fs::create_directories(archive_path.parent_path(), parent_ec); - if (parent_ec) { - throw std::runtime_error("Could not create archive parent directory: " + archive_path.parent_path().string()); - } + mkdir_or_throw(archive_path.parent_path().string()); } int errorp = 0; - zip_t *zf = zip_open(options.output_path.c_str(), ZIP_CREATE + ZIP_TRUNCATE, &errorp); - if (zf == nullptr) { - throw std::runtime_error("Could not open archive " + options.output_path + ": " + strerror(errorp)); - } - zip_from_folder(zf, 
output_dir, output_dir, options.compression_standard, nullptr); - if (zip_close(zf) != 0) { - throw std::runtime_error("Unable to close archive " + options.output_path + ": " + zip_strerror(zf)); + detail::ZipArchive zf(zip_open(options.output_path.c_str(), ZIP_CREATE + ZIP_TRUNCATE, &errorp)); + if (!zf) { + throw TrxIOError("Could not open archive " + options.output_path + ": " + strerror(errorp)); } + zip_from_folder(zf.get(), output_dir, output_dir, options.compression_standard, nullptr); + zf.commit(options.output_path); rm_dir(output_dir); } }; // namespace trx \ No newline at end of file diff --git a/tests/test_trx_anytrxfile.cpp b/tests/test_trx_anytrxfile.cpp index 9803d53..a9d51f7 100644 --- a/tests/test_trx_anytrxfile.cpp +++ b/tests/test_trx_anytrxfile.cpp @@ -1,8 +1,6 @@ #include #include -#define private public #include -#undef private #include #include @@ -419,7 +417,7 @@ TEST(AnyTrxFile, MissingHeaderCountsThrows) { header_obj.erase("NB_VERTICES"); write_header_file(corrupt_dir, json(header_obj)); - EXPECT_THROW(load_any(corrupt_dir.string()), std::invalid_argument); + EXPECT_THROW(load_any(corrupt_dir.string()), trx::TrxFormatError); std::error_code ec; fs::remove_all(temp_root, ec); @@ -433,7 +431,7 @@ TEST(AnyTrxFile, WrongPositionsDimThrows) { const fs::path positions = find_file_with_prefix(corrupt_dir, "positions"); rename_with_new_dim(positions, 4); - EXPECT_THROW(load_any(corrupt_dir.string()), std::invalid_argument); + EXPECT_THROW(load_any(corrupt_dir.string()), trx::TrxFormatError); std::error_code ec; fs::remove_all(temp_root, ec); @@ -448,7 +446,7 @@ TEST(AnyTrxFile, UnsupportedPositionsDtypeThrows) { const std::string ext = get_ext(positions.string()); rename_with_new_ext(positions, pick_int_dtype_same_size(ext)); - EXPECT_THROW(load_any(corrupt_dir.string()), std::invalid_argument); + EXPECT_THROW(load_any(corrupt_dir.string()), trx::TrxDTypeError); std::error_code ec; fs::remove_all(temp_root, ec); @@ -462,7 +460,7 @@ 
TEST(AnyTrxFile, WrongOffsetsDimThrows) { const fs::path offsets = find_file_with_prefix(corrupt_dir, "offsets"); rename_with_new_dim(offsets, 2); - EXPECT_THROW(load_any(corrupt_dir.string()), std::invalid_argument); + EXPECT_THROW(load_any(corrupt_dir.string()), trx::TrxFormatError); std::error_code ec; fs::remove_all(temp_root, ec); @@ -477,7 +475,7 @@ TEST(AnyTrxFile, UnsupportedOffsetsDtypeThrows) { const std::string ext = get_ext(offsets.string()); rename_with_new_ext(offsets, pick_int_dtype_same_size(ext)); - EXPECT_THROW(load_any(corrupt_dir.string()), std::invalid_argument); + EXPECT_THROW(load_any(corrupt_dir.string()), trx::TrxDTypeError); std::error_code ec; fs::remove_all(temp_root, ec); @@ -499,7 +497,7 @@ TEST(AnyTrxFile, WrongDpsDimThrows) { const fs::path dps_file = find_first_file_recursive(dps_dir); rename_with_new_dim(dps_file, 2); - EXPECT_THROW(load_any(corrupt_dir.string()), std::invalid_argument); + EXPECT_THROW(load_any(corrupt_dir.string()), trx::TrxFormatError); std::error_code ec; fs::remove_all(temp_root, ec); @@ -521,7 +519,7 @@ TEST(AnyTrxFile, WrongDpvDimThrows) { const fs::path dpv_file = find_first_file_recursive(dpv_dir); rename_with_new_dim(dpv_file, 2); - EXPECT_THROW(load_any(corrupt_dir.string()), std::invalid_argument); + EXPECT_THROW(load_any(corrupt_dir.string()), trx::TrxFormatError); std::error_code ec; fs::remove_all(temp_root, ec); @@ -542,7 +540,7 @@ TEST(AnyTrxFile, WrongDpgDimThrows) { const fs::path dpg_file = find_first_file_recursive(dpg_dir); rename_with_new_dim(dpg_file, 2); - EXPECT_THROW(load_any(corrupt_dir.string()), std::invalid_argument); + EXPECT_THROW(load_any(corrupt_dir.string()), trx::TrxFormatError); std::error_code ec; fs::remove_all(temp_root, ec); @@ -564,7 +562,7 @@ TEST(AnyTrxFile, UnsupportedGroupDtypeThrows) { const fs::path group_file = find_first_file_recursive(groups_dir); rename_with_new_ext(group_file, "int32"); - EXPECT_THROW(load_any(corrupt_dir.string()), std::invalid_argument); + 
EXPECT_THROW(load_any(corrupt_dir.string()), trx::TrxDTypeError); std::error_code ec; fs::remove_all(temp_root, ec); @@ -583,7 +581,7 @@ TEST(AnyTrxFile, InvalidEntryThrows) { out.write(value_bytes.data(), static_cast(value_bytes.size())); out.close(); - EXPECT_THROW(load_any(corrupt_dir.string()), std::invalid_argument); + EXPECT_THROW(load_any(corrupt_dir.string()), trx::TrxFormatError); std::error_code ec; fs::remove_all(temp_root, ec); @@ -598,7 +596,7 @@ TEST(AnyTrxFile, MissingEssentialDataThrows) { std::error_code ec; fs::remove(positions, ec); - EXPECT_THROW(load_any(corrupt_dir.string()), std::invalid_argument); + EXPECT_THROW(load_any(corrupt_dir.string()), trx::TrxFormatError); fs::remove_all(temp_root, ec); } @@ -631,7 +629,7 @@ TEST(AnyTrxFile, OffsetsOverflowThrows) { out.write(data_bytes.data(), static_cast(data_bytes.size())); out.close(); - EXPECT_THROW(load_any(corrupt_dir.string()), std::runtime_error); + EXPECT_THROW(load_any(corrupt_dir.string()), trx::TrxFormatError); std::error_code ec; fs::remove_all(temp_root, ec); @@ -644,7 +642,7 @@ TEST(AnyTrxFile, SaveRejectsUnsupportedExtension) { const auto temp_dir = make_temp_test_dir("trx_any_save_badext"); const fs::path out_path = temp_dir / "bad.txt"; - EXPECT_THROW(trx.save(out_path.string(), ZIP_CM_STORE), std::invalid_argument); + EXPECT_THROW(trx.save(out_path.string(), ZIP_CM_STORE), trx::TrxDTypeError); trx.close(); std::error_code ec; @@ -659,7 +657,7 @@ TEST(AnyTrxFile, SaveRejectsMissingOffsets) { trx.offsets = TypedArray(); const auto temp_dir = make_temp_test_dir("trx_any_save_no_offsets"); const fs::path out_path = temp_dir / "missing_offsets.trx"; - EXPECT_THROW(trx.save(out_path.string(), ZIP_CM_STORE), std::runtime_error); + EXPECT_THROW(trx.save(out_path.string(), ZIP_CM_STORE), trx::TrxFormatError); trx.close(); std::error_code ec; @@ -674,7 +672,7 @@ TEST(AnyTrxFile, SaveRejectsMissingDecodedOffsets) { trx.offsets_u64.clear(); const auto temp_dir = 
make_temp_test_dir("trx_any_save_no_offsets_u64"); const fs::path out_path = temp_dir / "missing_offsets_u64.trx"; - EXPECT_THROW(trx.save(out_path.string(), ZIP_CM_STORE), std::runtime_error); + EXPECT_THROW(trx.save(out_path.string(), ZIP_CM_STORE), trx::TrxFormatError); trx.close(); std::error_code ec; @@ -692,7 +690,7 @@ TEST(AnyTrxFile, SaveRejectsStreamlineCountMismatch) { const auto temp_dir = make_temp_test_dir("trx_any_save_bad_streamlines"); const fs::path out_path = temp_dir / "bad_streamlines.trx"; - EXPECT_THROW(trx.save(out_path.string(), ZIP_CM_STORE), std::runtime_error); + EXPECT_THROW(trx.save(out_path.string(), ZIP_CM_STORE), trx::TrxFormatError); trx.close(); std::error_code ec; @@ -710,7 +708,7 @@ TEST(AnyTrxFile, SaveRejectsVertexCountMismatch) { const auto temp_dir = make_temp_test_dir("trx_any_save_bad_vertices"); const fs::path out_path = temp_dir / "bad_vertices.trx"; - EXPECT_THROW(trx.save(out_path.string(), ZIP_CM_STORE), std::runtime_error); + EXPECT_THROW(trx.save(out_path.string(), ZIP_CM_STORE), trx::TrxFormatError); trx.close(); std::error_code ec; @@ -728,7 +726,7 @@ TEST(AnyTrxFile, SaveRejectsNonMonotonicOffsets) { const auto temp_dir = make_temp_test_dir("trx_any_save_non_mono"); const fs::path out_path = temp_dir / "non_mono.trx"; - EXPECT_THROW(trx.save(out_path.string(), ZIP_CM_STORE), std::runtime_error); + EXPECT_THROW(trx.save(out_path.string(), ZIP_CM_STORE), trx::TrxFormatError); trx.close(); std::error_code ec; @@ -746,7 +744,7 @@ TEST(AnyTrxFile, SaveRejectsPositionsRowMismatch) { const auto temp_dir = make_temp_test_dir("trx_any_save_bad_positions"); const fs::path out_path = temp_dir / "bad_positions.trx"; - EXPECT_THROW(trx.save(out_path.string(), ZIP_CM_STORE), std::runtime_error); + EXPECT_THROW(trx.save(out_path.string(), ZIP_CM_STORE), trx::TrxFormatError); trx.close(); std::error_code ec; @@ -758,12 +756,12 @@ TEST(AnyTrxFile, SaveRejectsMissingBackingDirectory) { const fs::path gs_trx = gs_dir / 
"gs_fldr.trx"; auto trx = load_any(gs_trx.string()); - trx._backing_directory.clear(); - trx._uncompressed_folder_handle.clear(); + trx.backing_directory().clear(); + trx.uncompressed_folder_handle().clear(); const auto temp_dir = make_temp_test_dir("trx_any_save_no_backing"); const fs::path out_path = temp_dir / "no_backing.trx"; - EXPECT_THROW(trx.save(out_path.string(), ZIP_CM_STORE), std::runtime_error); + EXPECT_THROW(trx.save(out_path.string(), ZIP_CM_STORE), trx::TrxIOError); trx.close(); std::error_code ec; @@ -858,7 +856,7 @@ TEST(AnyTrxFile, MergeTrxShardsSchemaMismatchThrows) { options.shard_directories = {shard1.string(), shard2.string()}; options.output_path = output_dir.string(); options.output_directory = true; - EXPECT_THROW(merge_trx_shards(options), std::runtime_error); + EXPECT_THROW(merge_trx_shards(options), trx::TrxFormatError); fs::remove_all(temp_root, ec); } @@ -999,7 +997,7 @@ TEST(AnyTrxFile, MergeTrxShardsRejectsDpg) { options.shard_directories = {shard1.string(), shard2.string()}; options.output_path = (temp_root / "merged").string(); options.output_directory = true; - EXPECT_THROW(merge_trx_shards(options), std::runtime_error); + EXPECT_THROW(merge_trx_shards(options), trx::TrxArgumentError); fs::remove_all(temp_root, ec); } @@ -1021,7 +1019,7 @@ TEST(AnyTrxFile, PreparePositionsOutputOverwriteFalseThrows) { PrepareOutputOptions options; options.overwrite_existing = false; - EXPECT_THROW(prepare_positions_output(input, output_dir.string(), options), std::runtime_error); + EXPECT_THROW(prepare_positions_output(input, output_dir.string(), options), trx::TrxIOError); input.close(); fs::remove_all(temp_root, ec); diff --git a/tests/test_trx_io.cpp b/tests/test_trx_io.cpp index 46c492a..2c5bac0 100644 --- a/tests/test_trx_io.cpp +++ b/tests/test_trx_io.cpp @@ -711,20 +711,20 @@ TEST(TrxFileIo, add_dps_from_text_errors) { const fs::path input_path = tmp_dir / "dps.txt"; write_text_file(input_path, "1.0"); - 
EXPECT_THROW(trx.add_dps_from_text("", "float32", input_path.string()), std::invalid_argument); - EXPECT_THROW(trx.add_dps_from_text("weight", "badtype", input_path.string()), std::invalid_argument); - EXPECT_THROW(trx.add_dps_from_text("weight", "int32", input_path.string()), std::invalid_argument); + EXPECT_THROW(trx.add_dps_from_text("", "float32", input_path.string()), trx::TrxArgumentError); + EXPECT_THROW(trx.add_dps_from_text("weight", "badtype", input_path.string()), trx::TrxDTypeError); + EXPECT_THROW(trx.add_dps_from_text("weight", "int32", input_path.string()), trx::TrxDTypeError); - EXPECT_THROW(trx.add_dps_from_text("weight", "float32", (tmp_dir / "missing.txt").string()), std::runtime_error); + EXPECT_THROW(trx.add_dps_from_text("weight", "float32", (tmp_dir / "missing.txt").string()), trx::TrxIOError); write_text_file(input_path, "1.0 abc"); - EXPECT_THROW(trx.add_dps_from_text("weight", "float32", input_path.string()), std::runtime_error); + EXPECT_THROW(trx.add_dps_from_text("weight", "float32", input_path.string()), trx::TrxFormatError); write_text_file(input_path, "1.0"); - EXPECT_THROW(trx.add_dps_from_text("weight", "float32", input_path.string()), std::runtime_error); + EXPECT_THROW(trx.add_dps_from_text("weight", "float32", input_path.string()), trx::TrxFormatError); trx::TrxFile empty; - EXPECT_THROW(empty.add_dps_from_text("weight", "float32", input_path.string()), std::runtime_error); + EXPECT_THROW(empty.add_dps_from_text("weight", "float32", input_path.string()), trx::TrxIOError); } TEST(TrxFileIo, add_dpv_from_tsf_success) { @@ -763,34 +763,34 @@ TEST(TrxFileIo, add_dpv_from_tsf_errors) { const fs::path input_path = tmp_dir / "dpv.tsf"; write_tsf_text_file(input_path, build_tsf_contents({{0.1, 0.2}, {0.3}})); - EXPECT_THROW(trx.add_dpv_from_tsf("", "float32", input_path.string()), std::invalid_argument); - EXPECT_THROW(trx.add_dpv_from_tsf("signal", "badtype", input_path.string()), std::invalid_argument); - 
EXPECT_THROW(trx.add_dpv_from_tsf("signal", "int32", input_path.string()), std::invalid_argument); + EXPECT_THROW(trx.add_dpv_from_tsf("", "float32", input_path.string()), trx::TrxArgumentError); + EXPECT_THROW(trx.add_dpv_from_tsf("signal", "badtype", input_path.string()), trx::TrxDTypeError); + EXPECT_THROW(trx.add_dpv_from_tsf("signal", "int32", input_path.string()), trx::TrxDTypeError); - EXPECT_THROW(trx.add_dpv_from_tsf("signal", "float32", (tmp_dir / "missing.tsf").string()), std::runtime_error); + EXPECT_THROW(trx.add_dpv_from_tsf("signal", "float32", (tmp_dir / "missing.tsf").string()), trx::TrxIOError); write_tsf_text_file(input_path, "0.1 0.2 abc"); - EXPECT_THROW(trx.add_dpv_from_tsf("signal", "float32", input_path.string()), std::runtime_error); + EXPECT_THROW(trx.add_dpv_from_tsf("signal", "float32", input_path.string()), trx::TrxFormatError); write_tsf_text_file(input_path, build_tsf_contents({{0.1}, {0.2, 0.3}})); - EXPECT_THROW(trx.add_dpv_from_tsf("signal", "float32", input_path.string()), std::runtime_error); + EXPECT_THROW(trx.add_dpv_from_tsf("signal", "float32", input_path.string()), trx::TrxFormatError); write_tsf_text_file(input_path, build_tsf_contents({{0.1, 0.2}, {0.3}})); - EXPECT_THROW(trx.add_dpv_from_tsf("signal", "float32", input_path.string()), std::runtime_error); + EXPECT_THROW(trx.add_dpv_from_tsf("signal", "float32", input_path.string()), trx::TrxFormatError); write_text_file(input_path, "mrtrix track scalars\nfile: . 
0\ndatatype: Float32LE\ntimestamp: 0\n0.1 0.2"); - EXPECT_THROW(trx.add_dpv_from_tsf("signal", "float32", input_path.string()), std::runtime_error); + EXPECT_THROW(trx.add_dpv_from_tsf("signal", "float32", input_path.string()), trx::TrxFormatError); trx::TrxFile empty; - EXPECT_THROW(empty.add_dpv_from_tsf("signal", "float32", input_path.string()), std::runtime_error); + EXPECT_THROW(empty.add_dpv_from_tsf("signal", "float32", input_path.string()), trx::TrxFormatError); trx::TrxFile no_dir(4, 2); set_streamline_lengths(no_dir.streamlines.get(), {2, 2}); // Intentional white-box access: there is no public API to construct a TrxFile // with valid streamlines but without an uncompressed folder. This test verifies // that add_dpv_from_tsf fails in that specific internal state. - no_dir._uncompressed_folder_handle.clear(); - EXPECT_THROW(no_dir.add_dpv_from_tsf("signal", "float32", input_path.string()), std::runtime_error); + no_dir.uncompressed_folder_handle().clear(); + EXPECT_THROW(no_dir.add_dpv_from_tsf("signal", "float32", input_path.string()), trx::TrxIOError); } TEST(TrxFileIo, export_dpv_to_tsf_success) { @@ -831,11 +831,11 @@ TEST(TrxFileIo, export_dpv_to_tsf_errors) { const fs::path tmp_dir = make_temp_test_dir("trx_export_tsf_err"); const fs::path output_path = tmp_dir / "signal.tsf"; - EXPECT_THROW(trx.export_dpv_to_tsf("", output_path.string(), "1"), std::invalid_argument); - EXPECT_THROW(trx.export_dpv_to_tsf("signal", output_path.string(), ""), std::invalid_argument); - EXPECT_THROW(trx.export_dpv_to_tsf("signal", output_path.string(), "1", "int32"), std::invalid_argument); - EXPECT_THROW(trx.export_dpv_to_tsf("missing", output_path.string(), "1"), std::runtime_error); + EXPECT_THROW(trx.export_dpv_to_tsf("", output_path.string(), "1"), trx::TrxArgumentError); + EXPECT_THROW(trx.export_dpv_to_tsf("signal", output_path.string(), ""), trx::TrxArgumentError); + EXPECT_THROW(trx.export_dpv_to_tsf("signal", output_path.string(), "1", "int32"), 
trx::TrxDTypeError); + EXPECT_THROW(trx.export_dpv_to_tsf("missing", output_path.string(), "1"), trx::TrxFormatError); trx::TrxFile empty; - EXPECT_THROW(empty.export_dpv_to_tsf("signal", output_path.string(), "1"), std::runtime_error); + EXPECT_THROW(empty.export_dpv_to_tsf("signal", output_path.string(), "1"), trx::TrxFormatError); } diff --git a/tests/test_trx_mmap.cpp b/tests/test_trx_mmap.cpp index 7eb443b..09e8224 100644 --- a/tests/test_trx_mmap.cpp +++ b/tests/test_trx_mmap.cpp @@ -351,12 +351,12 @@ TEST(TrxFileMemmap, detect_positions_scalar_type_fallback) { ASSERT_TRUE(out.is_open()); out.close(); - EXPECT_THROW(trx::detect_positions_scalar_type(invalid_dir.string(), TrxScalarType::Float64), std::invalid_argument); + EXPECT_THROW(trx::detect_positions_scalar_type(invalid_dir.string(), TrxScalarType::Float64), trx::TrxError); } TEST(TrxFileMemmap, detect_positions_scalar_type_missing_path) { const fs::path missing = fs::path(make_temp_test_dir("trx_scalar_missing")) / "nope"; - EXPECT_THROW(trx::detect_positions_scalar_type(missing.string(), TrxScalarType::Float32), std::runtime_error); + EXPECT_THROW(trx::detect_positions_scalar_type(missing.string(), TrxScalarType::Float32), trx::TrxError); } TEST(TrxFileMemmap, open_zip_for_read_generic_fallback) { @@ -416,36 +416,36 @@ TEST(TrxFileMemmap, __split_ext_with_dimensionality) { { try { output = trx::detail::_split_ext_with_dimensionality(fn3); - } catch (const std::invalid_argument &e) { + } catch (const trx::TrxError &e) { EXPECT_STREQ("Invalid filename", e.what()); throw; } }, - std::invalid_argument); + trx::TrxFormatError); const std::string fn4 = "mean_fa.5.4.int32"; EXPECT_THROW( { try { output = trx::detail::_split_ext_with_dimensionality(fn4); - } catch (const std::invalid_argument &e) { + } catch (const trx::TrxError &e) { EXPECT_STREQ("Invalid filename", e.what()); throw; } }, - std::invalid_argument); + trx::TrxFormatError); const std::string fn5 = "mean_fa.fa"; EXPECT_THROW( { try { output = 
trx::detail::_split_ext_with_dimensionality(fn5); - } catch (const std::invalid_argument &e) { - EXPECT_STREQ("Unsupported file extension", e.what()); + } catch (const trx::TrxDTypeError &e) { + EXPECT_TRUE(std::string(e.what()).find("Unsupported file extension") != std::string::npos); throw; } }, - std::invalid_argument); + trx::TrxDTypeError); } // Mirrors trx/tests/test_memmap.py::test__compute_lengths. @@ -711,7 +711,7 @@ TEST(TrxFileMemmap, load_missing_trx_throws) { const auto memmap_dir = resolve_memmap_test_data_dir(root); const fs::path missing_trx = memmap_dir / "dontexist.trx"; - EXPECT_THROW(trx::TrxReader(missing_trx.string()), std::runtime_error); + EXPECT_THROW(trx::TrxReader(missing_trx.string()), trx::TrxError); } // validates C++ TrxFile initialization. @@ -841,7 +841,7 @@ TEST(TrxFileMemmap, query_aabb_rejects_bad_aabb_size) { std::array max_corner{1.0f, 1.0f, 1.0f}; std::vector> bad_aabbs(1); - EXPECT_THROW(trx->query_aabb(min_corner, max_corner, &bad_aabbs), std::invalid_argument); + EXPECT_THROW(trx->query_aabb(min_corner, max_corner, &bad_aabbs), trx::TrxError); trx->close(); } @@ -888,7 +888,7 @@ TEST(TrxFileMemmap, subset_streamlines_empty) { TEST(TrxFileMemmap, subset_streamlines_out_of_range) { auto trx = create_small_trx(); std::vector ids{99}; - EXPECT_THROW(trx->subset_streamlines(ids), std::invalid_argument); + EXPECT_THROW(trx->subset_streamlines(ids), trx::TrxError); trx->close(); } @@ -926,10 +926,10 @@ TEST(TrxFileMemmap, dpg_api_invalid_inputs) { auto trx = create_small_trx(); std::vector values{1.0f, 2.0f, 3.0f}; - EXPECT_THROW(trx->add_dpg_from_vector("", "dpg", "float32", values), std::invalid_argument); - EXPECT_THROW(trx->add_dpg_from_vector("g1", "", "float32", values), std::invalid_argument); - EXPECT_THROW(trx->add_dpg_from_vector("g1", "dpg", "int8", values), std::invalid_argument); - EXPECT_THROW(trx->add_dpg_from_vector("g1", "dpg", "float32", values, 2, 2), std::invalid_argument); + 
EXPECT_THROW(trx->add_dpg_from_vector("", "dpg", "float32", values), trx::TrxError); + EXPECT_THROW(trx->add_dpg_from_vector("g1", "", "float32", values), trx::TrxError); + EXPECT_THROW(trx->add_dpg_from_vector("g1", "dpg", "int8", values), trx::TrxError); + EXPECT_THROW(trx->add_dpg_from_vector("g1", "dpg", "float32", values, 2, 2), trx::TrxError); trx->close(); } diff --git a/tests/test_trx_trxfile.cpp b/tests/test_trx_trxfile.cpp index 2336535..28a9fb8 100644 --- a/tests/test_trx_trxfile.cpp +++ b/tests/test_trx_trxfile.cpp @@ -2,9 +2,7 @@ #include #include -#define private public #include -#undef private #include #include @@ -417,7 +415,7 @@ TEST(TrxFileTpp, NormalizeForSaveRejectsNonMonotonicOffsets) { src->streamlines->_offsets(1) = 5; src->streamlines->_offsets(2) = 4; - EXPECT_THROW(src->normalize_for_save(), std::runtime_error); + EXPECT_THROW(src->normalize_for_save(), trx::TrxError); src->close(); @@ -456,7 +454,7 @@ TEST(TrxFileTpp, LoadFromDirectoryMissingHeader) { std::ofstream f(dummy.string(), std::ios::binary); f.close(); - EXPECT_THROW(TrxFile::load_from_directory(tmp_dir.string()), std::runtime_error); + EXPECT_THROW(TrxFile::load_from_directory(tmp_dir.string()), trx::TrxError); std::error_code ec; fs::remove_all(tmp_dir, ec);