diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 6b461d5..45118c6 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -99,13 +99,15 @@ jobs:
         run: |
           lcov --capture --directory build \
             --rc geninfo_unexecuted_blocks=1 \
-            --ignore-errors mismatch,inconsistent,unsupported,format \
+            --ignore-errors mismatch,mismatch,inconsistent,inconsistent,unsupported,format,format \
             --output-file coverage.info
           lcov --remove coverage.info "/opt/homebrew/*" "/Applications/Xcode_*.app/*" \
             "/Users/runner/work/trx-cpp/trx-cpp/third_party/*" \
-            --ignore-errors mismatch,inconsistent,unsupported,format \
+            "/Users/runner/work/trx-cpp/trx-cpp/examples/*" \
+            "/Users/runner/work/trx-cpp/trx-cpp/bench/*" \
+            --ignore-errors mismatch,mismatch,inconsistent,inconsistent,unsupported,format,format,unused \
             --output-file coverage.info
-          lcov --summary coverage.info --ignore-errors mismatch,inconsistent,unsupported,format
+          lcov --summary coverage.info --ignore-errors mismatch,mismatch,inconsistent,inconsistent,unsupported,format,format

       - name: Upload coverage to Codecov (macOS)
         if: runner.os == 'macOS'
diff --git a/.github/workflows/trx-cpp-tests.yml b/.github/workflows/trx-cpp-tests.yml
index 3ad1b81..03656f3 100644
--- a/.github/workflows/trx-cpp-tests.yml
+++ b/.github/workflows/trx-cpp-tests.yml
@@ -66,9 +66,15 @@ jobs:
         run: |
           lcov --capture --directory build \
             --rc geninfo_unexecuted_blocks=1 \
-            --ignore-errors mismatch \
+            --ignore-errors mismatch,mismatch,inconsistent,inconsistent,unsupported,format,format \
             --output-file coverage.info
-          lcov --summary coverage.info
+          lcov --remove coverage.info \
+            "*/third_party/*" \
+            "*/examples/*" \
+            "*/bench/*" \
+            --ignore-errors mismatch,mismatch,inconsistent,inconsistent,unsupported,format,format,unused \
+            --output-file coverage.info
+          lcov --summary coverage.info --ignore-errors mismatch,mismatch,inconsistent,inconsistent,unsupported,format,format

       - name: Upload to Codecov
         uses: codecov/codecov-action@v4
diff --git a/.gitignore b/.gitignore
index 43bcab0..66a6152 100644
--- a/.gitignore
+++ b/.gitignore
@@ -3,6 +3,7 @@
 **/*.cmake
 **/Makefile
 **/build
+**/build-release
 **/bin
 **/load_trx
 **/tests/memmap
@@ -15,6 +16,8 @@ test_package/CMakeUserPresets.json
 syntax.log
 docs/_build/
 docs/api/
-
+.DS_Store
 test_package/build
-test_package/CMakeUserPresets.json
\ No newline at end of file
+test_package/CMakeUserPresets.json
+test-data/*
+bench/results*
\ No newline at end of file
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 2dc74c4..cf0f267 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -24,6 +24,7 @@ endif()

 option(TRX_USE_CONAN "Should Conan package manager be used?" OFF)
 option(TRX_BUILD_TESTS "Build trx tests" OFF)
 option(TRX_BUILD_EXAMPLES "Build trx example commandline programs" ON)
+option(TRX_BUILD_BENCHMARKS "Build trx benchmarks" OFF)
 option(TRX_ENABLE_CLANG_TIDY "Run clang-tidy during builds" OFF)
 option(TRX_ENABLE_INSTALL "Install trx-cpp targets" ${TRX_IS_TOP_LEVEL})
 option(TRX_BUILD_DOCS "Build API documentation with Doxygen/Sphinx" OFF)
@@ -68,7 +69,11 @@ elseif(TARGET zip::zip)
 else()
   message(FATAL_ERROR "No suitable libzip target (expected libzip::libzip or zip::zip)")
 endif()
-find_package(Eigen3 CONFIG QUIET)
+# Prefer Eigen3_ROOT so -DEigen3_ROOT=/path/to/eigen-3.4 is used over system Eigen
+if(Eigen3_ROOT)
+  list(PREPEND CMAKE_PREFIX_PATH "${Eigen3_ROOT}")
+endif()
+find_package(Eigen3 3.4 CONFIG QUIET)
 if (NOT Eigen3_FOUND)
   find_package(Eigen3 REQUIRED) # try module mode
 endif()
@@ -148,6 +153,21 @@ if(TRX_BUILD_TESTS)
   endif()
 endif()

+if(TRX_BUILD_BENCHMARKS)
+  find_package(benchmark CONFIG QUIET)
+  if(NOT benchmark_FOUND)
+    set(BENCHMARK_ENABLE_TESTING OFF CACHE BOOL "Disable benchmark tests" FORCE)
+    set(BENCHMARK_ENABLE_GTEST_TESTS OFF CACHE BOOL "Disable benchmark gtest" FORCE)
+    FetchContent_Declare(
+      benchmark
+      GIT_REPOSITORY https://github.com/google/benchmark.git
+      GIT_TAG v1.8.3
+    )
+    FetchContent_MakeAvailable(benchmark)
+  endif()
+  add_subdirectory(bench)
+endif()
+
 if(TRX_ENABLE_NIFTI)
   find_package(ZLIB REQUIRED)
   add_library(trx-nifti
diff --git a/bench/CMakeLists.txt b/bench/CMakeLists.txt
new file mode 100644
index 0000000..e02d21d
--- /dev/null
+++ b/bench/CMakeLists.txt
@@ -0,0 +1,3 @@
+add_executable(bench_trx_realdata bench_trx_realdata.cpp)
+target_link_libraries(bench_trx_realdata PRIVATE trx benchmark::benchmark)
+target_compile_features(bench_trx_realdata PRIVATE cxx_std_17)
diff --git a/bench/bench_trx_realdata.cpp b/bench/bench_trx_realdata.cpp
new file mode 100644
index 0000000..035916d
--- /dev/null
+++ b/bench/bench_trx_realdata.cpp
@@ -0,0 +1,1150 @@
+// Benchmark TRX streaming workloads for realistic datasets.
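+//
+// Typical invocation (the filter and output names below are illustrative;
+// --reference-trx is the only required flag):
+//   ./bench_trx_realdata --reference-trx test-data/10milHCP_dps-sift2.trx \
+//     --benchmark_filter=BM_TrxFileSize_Float16 --benchmark_out=results.json \
+//     --benchmark_out_format=json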
+#include <benchmark/benchmark.h>
+#include <trx.h>  // umbrella project header (exact include name assumed)
+
+#include <algorithm>
+#include <array>
+#include <atomic>
+#include <chrono>
+#include <cmath>
+#include <cstdint>
+#include <cstdlib>
+#include <cstring>
+#include <filesystem>
+#include <fstream>
+#include <iostream>
+#include <limits>
+#include <memory>
+#include <mutex>
+#include <random>
+#include <stdexcept>
+#include <string>
+#include <thread>
+#include <tuple>
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+#if defined(__unix__) || defined(__APPLE__)
+#include <sys/resource.h>
+#include <sys/time.h>
+#include <unistd.h>
+#endif
+
+namespace {
+using Eigen::half;
+
+std::string g_reference_trx_path;
+
+constexpr float kMinLengthMm = 20.0f;
+constexpr float kMaxLengthMm = 500.0f;
+constexpr float kStepMm = 2.0f;
+constexpr float kSlabThicknessMm = 5.0f;
+constexpr size_t kSlabCount = 20;
+
+constexpr std::array<size_t, 5> kStreamlineCounts = {100000, 500000, 1000000, 5000000, 10000000};
+
+struct Fov {
+  float min_x;
+  float max_x;
+  float min_y;
+  float max_y;
+  float min_z;
+  float max_z;
+};
+
+constexpr Fov kFov{-70.0f, 70.0f, -108.0f, 79.0f, -60.0f, 75.0f};
+
+enum class GroupScenario : int { None = 0, Bundles = 1, Connectome = 2 };
+
+constexpr size_t kBundleCount = 80;
+constexpr std::array<size_t, 3> kConnectomeAtlasSizes = {80, 400, 1000};
+constexpr size_t kConnectomeTotalGroups = 1480; // sum of atlas sizes
+
+std::string make_temp_path(const std::string &prefix) {
+  static std::atomic<uint64_t> counter{0};
+  const auto id = counter.fetch_add(1, std::memory_order_relaxed);
+  const auto dir = std::filesystem::temp_directory_path();
+  return (dir / (prefix + "_" + std::to_string(id) + ".trx")).string();
+}
+
+std::string make_work_dir_name(const std::string &prefix) {
+  static std::atomic<uint64_t> counter{0};
+  const auto id = counter.fetch_add(1, std::memory_order_relaxed);
+#if defined(__unix__) || defined(__APPLE__)
+  const auto pid = static_cast<long>(getpid());
+#else
+  const auto pid = static_cast<long>(0);
+#endif
+  const auto dir = std::filesystem::current_path();
+  return (dir / (prefix + "_" + std::to_string(pid) + "_" + std::to_string(id))).string();
+}
+
+void register_cleanup(const std::string &path);
+
+size_t file_size_bytes(const std::string &path) {
+  std::error_code ec;
+  if (!trx::fs::exists(path, ec)) {
+    return 0;
+  }
+  if (trx::fs::is_directory(path, ec)) {
+    size_t total = 0;
+    for (trx::fs::recursive_directory_iterator it(path, ec), end; it != end; it.increment(ec)) {
+      if (ec) {
+        break;
+      }
+      if (!it->is_regular_file(ec)) {
+        continue;
+      }
+      total += static_cast<size_t>(trx::fs::file_size(it->path(), ec));
+      if (ec) {
+        break;
+      }
+    }
+    return total;
+  }
+  return static_cast<size_t>(trx::fs::file_size(path, ec));
+}
+
+double get_max_rss_kb() {
+#if defined(__unix__) || defined(__APPLE__)
+  rusage usage{};
+  if (getrusage(RUSAGE_SELF, &usage) != 0) {
+    return 0.0;
+  }
+#if defined(__APPLE__)
+  return static_cast<double>(usage.ru_maxrss) / 1024.0;
+#else
+  return static_cast<double>(usage.ru_maxrss);
+#endif
+#else
+  return 0.0;
+#endif
+}
+
+size_t parse_env_size(const char *name, size_t default_value) {
+  const char *raw = std::getenv(name);
+  if (!raw || raw[0] == '\0') {
+    return default_value;
+  }
+  char *end = nullptr;
+  const unsigned long long value = std::strtoull(raw, &end, 10);
+  if (end == raw) {
+    return default_value;
+  }
+  return static_cast<size_t>(value);
+}
+
+bool parse_env_bool(const char *name, bool default_value) {
+  const char *raw = std::getenv(name);
+  if (!raw || raw[0] == '\0') {
+    return default_value;
+  }
+  return std::string(raw) != "0";
+}
+
+int parse_env_int(const char *name, int default_value) {
+  const char *raw = std::getenv(name);
+  if (!raw || raw[0] == '\0') {
+    return default_value;
+  }
+  char *end = nullptr;
+  const long value = std::strtol(raw, &end, 10);
+  if (end == raw) {
+    return default_value;
+  }
+  return static_cast<int>(value);
+}
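+
+// The parse_env_* helpers above back the tuning knobs used throughout this
+// file: TRX_BENCH_PROFILE, TRX_BENCH_LOG, TRX_BENCH_CHILD_LOG,
+// TRX_BENCH_LOG_PROGRESS_EVERY, TRX_BENCH_ONLY_STREAMLINES,
+// TRX_BENCH_MAX_STREAMLINES, TRX_BENCH_BUFFER_MULTIPLIER,
+// TRX_BENCH_SKIP_ZIP_AT, TRX_BENCH_SKIP_DPV_AT, TRX_BENCH_SKIP_CONNECTOME_AT,
+// TRX_BENCH_CHUNK_BYTES, TRX_RSS_SAMPLES_PATH, TRX_RSS_SAMPLE_EVERY,
+// TRX_RSS_SAMPLE_MS and TRX_QUERY_TIMINGS_PATH. All are optional.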
+
+bool is_core_profile() {
+  const char *raw = std::getenv("TRX_BENCH_PROFILE");
+  return raw && std::string(raw) == "core";
+}
+
+bool include_bundles_in_core_profile() {
+  return parse_env_bool("TRX_BENCH_CORE_INCLUDE_BUNDLES", false);
+}
+
+size_t core_dpv_max_streamlines() {
+  return parse_env_size("TRX_BENCH_CORE_DPV_MAX_STREAMLINES", 1000000);
+}
+
+size_t core_zip_max_streamlines() {
+  return parse_env_size("TRX_BENCH_CORE_ZIP_MAX_STREAMLINES", 1000000);
+}
+
+std::vector<int64_t> group_cases_for_benchmarks() {
+  std::vector<int64_t> groups = {static_cast<int64_t>(GroupScenario::None)};
+  if (!is_core_profile() || include_bundles_in_core_profile()) {
+    groups.push_back(static_cast<int64_t>(GroupScenario::Bundles));
+  }
+  if (parse_env_bool("TRX_BENCH_INCLUDE_CONNECTOME", !is_core_profile())) {
+    groups.push_back(static_cast<int64_t>(GroupScenario::Connectome));
+  }
+  return groups;
+}
+
+size_t group_count_for(GroupScenario scenario) {
+  switch (scenario) {
+    case GroupScenario::Bundles:
+      return kBundleCount;
+    case GroupScenario::Connectome:
+      return kConnectomeTotalGroups;
+    case GroupScenario::None:
+    default:
+      return 0;
+  }
+}
+
+// Compute position buffer size based on streamline count.
+// For slow storage (spinning disks, network filesystems), set TRX_BENCH_BUFFER_MULTIPLIER
+// to 2-8 to reduce I/O frequency at the cost of higher memory usage.
+// Example: multiplier=4 scales 256 MB → 1 GB for 1M streamlines.
+std::size_t buffer_bytes_for_streamlines(std::size_t streamlines) {
+  std::size_t base_bytes;
+  if (streamlines >= 5000000) {
+    base_bytes = 2ULL * 1024ULL * 1024ULL * 1024ULL; // 2 GB
+  } else if (streamlines >= 1000000) {
+    base_bytes = 256ULL * 1024ULL * 1024ULL; // 256 MB
+  } else {
+    base_bytes = 16ULL * 1024ULL * 1024ULL; // 16 MB
+  }
+
+  // Allow scaling buffer sizes for slower storage (HDD, NFS) to amortize I/O latency
+  const size_t multiplier = std::max<std::size_t>(1, parse_env_size("TRX_BENCH_BUFFER_MULTIPLIER", 1));
+  return base_bytes * multiplier;
+}
+
+std::vector<size_t> streamlines_for_benchmarks() {
+  const size_t only = parse_env_size("TRX_BENCH_ONLY_STREAMLINES", 0);
+  if (only > 0) {
+    return {only};
+  }
+  const size_t max_val = parse_env_size("TRX_BENCH_MAX_STREAMLINES", 10000000);
+  std::vector<size_t> counts = {10000000, 5000000, 1000000, 500000, 100000};
+  counts.erase(std::remove_if(counts.begin(), counts.end(), [&](size_t v) { return v > max_val; }), counts.end());
+  if (counts.empty()) {
+    counts.push_back(max_val);
+  }
+  return counts;
+}
+
+void log_bench_start(const std::string &name, const std::string &details) {
+  if (!parse_env_bool("TRX_BENCH_LOG", false)) {
+    return;
+  }
+  std::cerr << "[trx-bench] start " << name << " " << details << std::endl;
+}
+
+void log_bench_end(const std::string &name, const std::string &details) {
+  if (!parse_env_bool("TRX_BENCH_LOG", false)) {
+    return;
+  }
+  std::cerr << "[trx-bench] end " << name << " " << details << std::endl;
+}
+
+const std::vector<std::string> &group_names_for(GroupScenario scenario) {
+  static const std::vector<std::string> empty;
+  static const std::vector<std::string> bundle_names = []() {
+    std::vector<std::string> names;
+    names.reserve(kBundleCount);
+    for (size_t i = 1; i <= kBundleCount; ++i) {
+      names.push_back("Bundle" + std::to_string(i));
+    }
+    return names;
+  }();
+  static const std::vector<std::string> connectome_names = []() {
+    std::vector<std::string> names;
+    names.reserve(kConnectomeTotalGroups);
+    for (size_t a = 0; a < kConnectomeAtlasSizes.size(); ++a) {
+      for (size_t r = 1; r <= kConnectomeAtlasSizes[a]; ++r) {
+        names.push_back("atlas" + std::to_string(a + 1) + "_region" + std::to_string(r));
+      }
+    }
+    return names;
+  }();
+  switch (scenario) {
+    case GroupScenario::Bundles:
+      return bundle_names;
+    case GroupScenario::Connectome:
+      return connectome_names;
+    case GroupScenario::None:
+    default:
+      return empty;
+  }
+}
+
+std::vector<uint32_t> build_prefix_ids(size_t num_streamlines) {
+  if (num_streamlines > static_cast<size_t>(std::numeric_limits<uint32_t>::max())) {
+    throw std::runtime_error("Too many streamlines for uint32 index space.");
+  }
+  std::vector<uint32_t> ids;
+  ids.reserve(num_streamlines);
+  for (size_t i = 0; i < num_streamlines; ++i) {
+    ids.push_back(static_cast<uint32_t>(i));
+  }
+  return ids;
+}
+
+void assign_groups_to_trx(trx::TrxFile<half> &trx, GroupScenario scenario, size_t streamlines) {
+  const auto group_count = group_count_for(scenario);
+  const auto &group_names = group_names_for(scenario);
+  if (group_count == 0) {
+    return;
+  }
+
+  if (scenario == GroupScenario::Connectome) {
+    std::vector<std::vector<uint32_t>> group_indices(kConnectomeTotalGroups);
+    std::mt19937 rng(42);
+    std::uniform_real_distribution<float> coin(0.0f, 1.0f);
+
+    for (size_t i = 0; i < streamlines; ++i) {
+      size_t group_offset = 0;
+      for (size_t a = 0; a < kConnectomeAtlasSizes.size(); ++a) {
+        const size_t n_regions = kConnectomeAtlasSizes[a];
+        const size_t picks = (coin(rng) < 0.1f) ? 3 : 2;
+        std::unordered_set<size_t> chosen;
+        std::uniform_int_distribution<size_t> region_dist(0, n_regions - 1);
+        while (chosen.size() < picks) {
+          chosen.insert(region_dist(rng));
+        }
+        for (size_t r : chosen) {
+          group_indices[group_offset + r].push_back(static_cast<uint32_t>(i));
+        }
+        group_offset += n_regions;
+      }
+    }
+    for (size_t g = 0; g < kConnectomeTotalGroups; ++g) {
+      trx.add_group_from_indices(group_names[g], group_indices[g]);
+    }
+  } else {
+    for (size_t g = 0; g < group_count; ++g) {
+      std::vector<uint32_t> group_indices;
+      group_indices.reserve(streamlines / group_count + 1);
+      for (size_t i = g; i < streamlines; i += group_count) {
+        group_indices.push_back(static_cast<uint32_t>(i));
+      }
+      trx.add_group_from_indices(group_names[g], group_indices);
+    }
+  }
+}
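+
+// Note on the connectome scenario above: each streamline joins two regions per
+// atlas (three with 10% probability), so across the three atlases a streamline
+// ends up in roughly six or seven of the 1480 groups.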
+
+std::unique_ptr<trx::TrxFile<half>> build_prefix_subset_trx(size_t streamlines,
+                                                            GroupScenario scenario,
+                                                            bool add_dps,
+                                                            bool add_dpv) {
+  if (g_reference_trx_path.empty()) {
+    throw std::runtime_error("Reference TRX path not set.");
+  }
+  auto ref_trx = trx::load(g_reference_trx_path);
+  const size_t ref_count = ref_trx->num_streamlines();
+  if (streamlines > ref_count) {
+    throw std::runtime_error("Requested " + std::to_string(streamlines) +
+                             " streamlines but reference only has " + std::to_string(ref_count));
+  }
+
+  const auto ids = build_prefix_ids(streamlines);
+  auto out = ref_trx->subset_streamlines(ids, false);
+
+  // Benchmark scenario owns grouping; drop any inherited groups first.
+  out->groups.clear();
+  if (!out->_uncompressed_folder_handle.empty()) {
+    std::error_code ec;
+    std::filesystem::remove_all(std::filesystem::path(out->_uncompressed_folder_handle) / "groups", ec);
+  }
+
+  if (!add_dps) {
+    out->data_per_streamline.clear();
+    if (!out->_uncompressed_folder_handle.empty()) {
+      std::error_code ec;
+      std::filesystem::remove_all(std::filesystem::path(out->_uncompressed_folder_handle) / "dps", ec);
+    }
+  } else if (out->data_per_streamline.empty()) {
+    std::vector<float> ones(streamlines, 1.0f);
+    out->add_dps_from_vector("sift_weights", "float32", ones);
+  }
+
+  if (add_dpv) {
+    std::vector<float> dpv(out->num_vertices(), 0.5f);
+    out->add_dpv_from_vector("dpv_random", "float32", dpv);
+  } else {
+    out->data_per_vertex.clear();
+    if (!out->_uncompressed_folder_handle.empty()) {
+      std::error_code ec;
+      std::filesystem::remove_all(std::filesystem::path(out->_uncompressed_folder_handle) / "dpv", ec);
+    }
+  }
+
+  assign_groups_to_trx(*out, scenario, streamlines);
+
+  return out;
+}
+
+struct TrxWriteStats {
+  double write_ms = 0.0;
+  double file_size_bytes = 0.0;
+};
+
+struct RssSample {
+  double elapsed_ms = 0.0;
+  double rss_kb = 0.0;
+  std::string phase;
+};
+
+struct FileSizeScenario {
+  size_t streamlines = 0;
+  bool add_dps = false;
+  bool add_dpv = false;
+  zip_uint32_t compression = ZIP_CM_STORE;
+};
+
+std::mutex g_rss_samples_mutex;
+
+void append_rss_samples(const FileSizeScenario &scenario, const std::vector<RssSample> &samples) {
+  if (samples.empty()) {
+    return;
+  }
+  const char *path = std::getenv("TRX_RSS_SAMPLES_PATH");
+  if (!path || path[0] == '\0') {
+    return;
+  }
+  std::lock_guard<std::mutex> lock(g_rss_samples_mutex);
+  std::ofstream out(path, std::ios::app);
+  if (!out.is_open()) {
+    return;
+  }
+
+  out << "{"
+      << "\"streamlines\":" << scenario.streamlines << ","
+      << "\"dps\":" << (scenario.add_dps ? 1 : 0) << ","
+      << "\"dpv\":" << (scenario.add_dpv ? 1 : 0) << ","
+      << "\"compression\":" << (scenario.compression == ZIP_CM_DEFLATE ? 1 : 0) << ","
+      << "\"samples\":[";
+  for (size_t i = 0; i < samples.size(); ++i) {
+    if (i > 0) {
+      out << ",";
+    }
+    out << "{"
+        << "\"elapsed_ms\":" << samples[i].elapsed_ms << ","
+        << "\"rss_kb\":" << samples[i].rss_kb << ","
+        << "\"phase\":\"" << samples[i].phase << "\""
+        << "}";
+  }
+  out << "]}\n";
+}
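+
+// Each record appended above is a single JSONL line shaped like (values
+// illustrative):
+//   {"streamlines":100000,"dps":0,"dpv":1,"compression":0,
+//    "samples":[{"elapsed_ms":12.5,"rss_kb":204800,"phase":"generate"},...]}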
+
+std::mutex g_cleanup_mutex;
+std::vector<std::string> g_cleanup_paths;
+pid_t g_cleanup_owner_pid = 0;
+bool g_cleanup_only_on_success = true;
+bool g_run_success = false;
+
+void cleanup_temp_paths() {
+  if (g_cleanup_only_on_success && !g_run_success) {
+    return;
+  }
+  if (g_cleanup_owner_pid != 0 && getpid() != g_cleanup_owner_pid) {
+    return;
+  }
+  std::error_code ec;
+  for (const auto &p : g_cleanup_paths) {
+    std::filesystem::remove_all(p, ec);
+  }
+}
+
+void register_cleanup(const std::string &path) {
+  static bool registered = false;
+  {
+    std::lock_guard<std::mutex> lock(g_cleanup_mutex);
+    if (g_cleanup_owner_pid == 0) {
+      g_cleanup_owner_pid = getpid();
+    }
+    g_cleanup_paths.push_back(path);
+  }
+  if (!registered) {
+    registered = true;
+    std::atexit(cleanup_temp_paths);
+  }
+}
+
+TrxWriteStats run_trx_file_size(size_t streamlines,
+                                bool add_dps,
+                                bool add_dpv,
+                                zip_uint32_t compression) {
+  const size_t progress_every = parse_env_size("TRX_BENCH_LOG_PROGRESS_EVERY", 0);
+
+  const bool collect_rss = std::getenv("TRX_RSS_SAMPLES_PATH") != nullptr;
+  const size_t sample_every = parse_env_size("TRX_RSS_SAMPLE_EVERY", 50000);
+  const int sample_interval_ms = parse_env_int("TRX_RSS_SAMPLE_MS", 500);
+  std::vector<RssSample> samples;
+  std::mutex samples_mutex;
+  const auto bench_start = std::chrono::steady_clock::now();
+  auto record_sample = [&](const std::string &phase) {
+    if (!collect_rss) {
+      return;
+    }
+    const auto now = std::chrono::steady_clock::now();
+    const std::chrono::duration<double, std::milli> elapsed = now - bench_start;
+    std::lock_guard<std::mutex> lock(samples_mutex);
+    samples.push_back({elapsed.count(), get_max_rss_kb(), phase});
+  };
+
+  auto trx_subset = build_prefix_subset_trx(streamlines, GroupScenario::None, add_dps, add_dpv);
+  if (progress_every > 0) {
+    std::cerr << "[trx-bench] progress file_size streamlines=" << streamlines << " / " << streamlines << std::endl;
+  }
+  if (collect_rss && sample_every > 0) {
+    record_sample("generate");
+  }
+
+  const std::string out_path = make_temp_path("trx_size");
+  record_sample("before_finalize");
+
+  std::atomic<bool> sampling{false};
+  std::thread sampler;
+  if (collect_rss) {
+    sampling.store(true, std::memory_order_relaxed);
+    sampler = std::thread([&]() {
+      while (sampling.load(std::memory_order_relaxed)) {
+        record_sample("finalize");
+        std::this_thread::sleep_for(std::chrono::milliseconds(sample_interval_ms));
+      }
+    });
+  }
+
+  trx::TrxSaveOptions save_opts;
+  save_opts.compression_standard = compression;
+  const auto start = std::chrono::steady_clock::now();
+  trx_subset->save(out_path, save_opts);
+  const auto end = std::chrono::steady_clock::now();
+  trx_subset->close();
+
+  if (collect_rss) {
+    sampling.store(false, std::memory_order_relaxed);
+    if (sampler.joinable()) {
+      sampler.join();
+    }
+  }
+  record_sample("after_finalize");
+
+  TrxWriteStats stats;
+  stats.write_ms = std::chrono::duration<double, std::milli>(end - start).count();
+  std::error_code size_ec;
+  const auto size = std::filesystem::file_size(out_path, size_ec);
+  stats.file_size_bytes = size_ec ? 0.0 : static_cast<double>(size);
+  std::error_code ec;
+  std::filesystem::remove(out_path, ec);
+
+  if (collect_rss) {
+    FileSizeScenario scenario;
+    scenario.streamlines = streamlines;
+    scenario.add_dps = add_dps;
+    scenario.add_dpv = add_dpv;
+    scenario.compression = compression;
+    append_rss_samples(scenario, samples);
+  }
+  return stats;
+}
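+
+// run_trx_file_size() above samples max-RSS on a background thread (every
+// TRX_RSS_SAMPLE_MS ms) while save() runs, so the "finalize" phase shows up as
+// a time series in the JSONL output rather than as a single end-of-run value.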
+
+struct TrxOnDisk {
+  std::string path;
+  size_t streamlines = 0;
+  size_t vertices = 0;
+  double shard_merge_ms = 0.0;
+  size_t shard_processes = 1;
+};
+
+TrxOnDisk build_trx_file_on_disk_single(size_t streamlines,
+                                        GroupScenario scenario,
+                                        bool add_dps,
+                                        bool add_dpv,
+                                        zip_uint32_t compression,
+                                        const std::string &out_path_override = "") {
+  const size_t progress_every = parse_env_size("TRX_BENCH_LOG_PROGRESS_EVERY", 0);
+
+  // Fast path: For 10M (full reference) WITHOUT DPV, just copy and add groups
+  // DPV requires 1.3B vertices × 8 bytes (vector + internal copy) = 10 GB extra memory!
+  if (g_reference_trx_path.empty()) {
+    throw std::runtime_error("Reference TRX path not set.");
+  }
+  auto ref_trx_check = trx::load(g_reference_trx_path);
+  const size_t ref_count = ref_trx_check->num_streamlines();
+  const bool is_full_reference = (streamlines == ref_count);
+  ref_trx_check.reset(); // Release mmap
+
+  // Fast path for full reference WITHOUT DPV: work with reference directly
+  // DPV + 10M requires 40-50 GB due to: mmap(7.6) + dpv_vec(4.8) + dpv_mmap(4.8) + save() overhead(20+)
+  if (is_full_reference && !add_dpv) {
+    log_bench_start("build_trx_copy_fast", "streamlines=" + std::to_string(streamlines));
+
+    // Filesystem copy reference (disk I/O only)
+    const std::string temp_copy = make_temp_path("trx_ref_copy");
+    std::error_code copy_ec;
+    std::filesystem::copy_file(g_reference_trx_path, temp_copy,
+                               std::filesystem::copy_options::overwrite_existing, copy_ec);
+    if (copy_ec) {
+      throw std::runtime_error("Failed to copy reference: " + copy_ec.message());
+    }
+
+    // Load and modify (just groups, no DPV)
+    auto trx = trx::load(temp_copy);
+
+    // Add groups
+    assign_groups_to_trx(*trx, scenario, streamlines);
+
+    // Note: DPV for 10M handled by sampling path (skipped by default due to 40-50 GB memory)
+
+    // Save to output
+    const std::string out_path = out_path_override.empty() ? make_temp_path("trx_input") : out_path_override;
+    trx::TrxSaveOptions save_opts;
+    save_opts.compression_standard = compression;
+    trx->save(out_path, save_opts);
+    const size_t total_vertices = trx->num_vertices();
+    trx->close();
+
+    std::filesystem::remove(temp_copy, copy_ec);
+
+    if (out_path_override.empty()) {
+      register_cleanup(out_path);
+    }
+
+    log_bench_end("build_trx_copy_fast", "streamlines=" + std::to_string(streamlines));
+    return {out_path, streamlines, total_vertices, 0.0, 1};
+  }
+
+  // Prefix-subset path: copy first N streamlines into a TrxFile and save.
+  auto trx_subset = build_prefix_subset_trx(streamlines, scenario, add_dps, add_dpv);
+  const size_t total_vertices = trx_subset->num_vertices();
+  const std::string out_path = out_path_override.empty() ? make_temp_path("trx_input") : out_path_override;
+  trx::TrxSaveOptions save_opts;
+  save_opts.compression_standard = compression;
+  trx_subset->save(out_path, save_opts);
+  trx_subset->close();
+  if (progress_every > 0 && (parse_env_bool("TRX_BENCH_CHILD_LOG", false) || parse_env_bool("TRX_BENCH_LOG", false))) {
+    std::cerr << "[trx-bench] progress build_trx streamlines=" << streamlines << " / " << streamlines << std::endl;
+  }
+  if (out_path_override.empty()) {
+    register_cleanup(out_path);
+  }
+  return {out_path, streamlines, total_vertices, 0.0, 1};
+}
+
+TrxOnDisk build_trx_file_on_disk(size_t streamlines,
+                                 GroupScenario scenario,
+                                 bool add_dps,
+                                 bool add_dpv,
+                                 zip_uint32_t compression) {
+  return build_trx_file_on_disk_single(streamlines, scenario, add_dps, add_dpv, compression);
+}
make_temp_path("trx_input") : out_path_override; + trx::TrxSaveOptions save_opts; + save_opts.compression_standard = compression; + trx_subset->save(out_path, save_opts); + trx_subset->close(); + if (progress_every > 0 && (parse_env_bool("TRX_BENCH_CHILD_LOG", false) || parse_env_bool("TRX_BENCH_LOG", false))) { + std::cerr << "[trx-bench] progress build_trx streamlines=" << streamlines << " / " << streamlines << std::endl; + } + if (out_path_override.empty()) { + register_cleanup(out_path); + } + return {out_path, streamlines, total_vertices, 0.0, 1}; +} + +TrxOnDisk build_trx_file_on_disk(size_t streamlines, + GroupScenario scenario, + bool add_dps, + bool add_dpv, + zip_uint32_t compression) { + return build_trx_file_on_disk_single(streamlines, scenario, add_dps, add_dpv, compression); +} + +struct QueryDataset { + std::unique_ptr> trx; + std::vector> slab_mins; + std::vector> slab_maxs; +}; + +void build_slabs(std::vector> &mins, std::vector> &maxs) { + mins.clear(); + maxs.clear(); + mins.reserve(kSlabCount); + maxs.reserve(kSlabCount); + const float z_range = kFov.max_z - kFov.min_z; + for (size_t i = 0; i < kSlabCount; ++i) { + const float t = (kSlabCount == 1) ? 0.5f : static_cast(i) / static_cast(kSlabCount - 1); + const float center_z = kFov.min_z + t * z_range; + const float min_z = std::max(kFov.min_z, center_z - kSlabThicknessMm * 0.5f); + const float max_z = std::min(kFov.max_z, center_z + kSlabThicknessMm * 0.5f); + mins.push_back({kFov.min_x, kFov.min_y, min_z}); + maxs.push_back({kFov.max_x, kFov.max_y, max_z}); + } +} + +struct ScenarioParams { + size_t streamlines = 0; + GroupScenario scenario = GroupScenario::None; + bool add_dps = false; + bool add_dpv = false; +}; + +struct KeyHash { + using Key = std::tuple; + size_t operator()(const Key &key) const { + size_t h = 0; + auto hash_combine = [&](size_t v) { + h ^= v + 0x9e3779b97f4a7c15ULL + (h << 6) + (h >> 2); + }; + hash_combine(std::hash{}(std::get<0>(key))); + hash_combine(std::hash{}(std::get<1>(key))); + hash_combine(std::hash{}(std::get<2>(key))); + hash_combine(std::hash{}(std::get<3>(key))); + return h; + } +}; + +void maybe_write_query_timings(const ScenarioParams &scenario, const std::vector &timings_ms) { + static std::mutex mutex; + static std::unordered_set seen; + const KeyHash::Key key{scenario.streamlines, + static_cast(scenario.scenario), + scenario.add_dps ? 1 : 0, + scenario.add_dpv ? 1 : 0}; + + std::lock_guard lock(mutex); + if (!seen.insert(key).second) { + return; + } + + const char *env_path = std::getenv("TRX_QUERY_TIMINGS_PATH"); + const std::filesystem::path out_path = env_path ? env_path : "bench/query_timings.jsonl"; + std::error_code ec; + if (!out_path.parent_path().empty()) { + std::filesystem::create_directories(out_path.parent_path(), ec); + } + std::ofstream out(out_path, std::ios::app); + if (!out.is_open()) { + return; + } + + out << "{" + << "\"streamlines\":" << scenario.streamlines << "," + << "\"group_case\":" << static_cast(scenario.scenario) << "," + << "\"group_count\":" << group_count_for(scenario.scenario) << "," + << "\"dps\":" << (scenario.add_dps ? 1 : 0) << "," + << "\"dpv\":" << (scenario.add_dpv ? 
+
+void maybe_write_query_timings(const ScenarioParams &scenario, const std::vector<double> &timings_ms) {
+  static std::mutex mutex;
+  static std::unordered_set<KeyHash::Key, KeyHash> seen;
+  const KeyHash::Key key{scenario.streamlines,
+                         static_cast<int>(scenario.scenario),
+                         scenario.add_dps ? 1 : 0,
+                         scenario.add_dpv ? 1 : 0};
+
+  std::lock_guard<std::mutex> lock(mutex);
+  if (!seen.insert(key).second) {
+    return;
+  }
+
+  const char *env_path = std::getenv("TRX_QUERY_TIMINGS_PATH");
+  const std::filesystem::path out_path = env_path ? env_path : "bench/query_timings.jsonl";
+  std::error_code ec;
+  if (!out_path.parent_path().empty()) {
+    std::filesystem::create_directories(out_path.parent_path(), ec);
+  }
+  std::ofstream out(out_path, std::ios::app);
+  if (!out.is_open()) {
+    return;
+  }
+
+  out << "{"
+      << "\"streamlines\":" << scenario.streamlines << ","
+      << "\"group_case\":" << static_cast<int>(scenario.scenario) << ","
+      << "\"group_count\":" << group_count_for(scenario.scenario) << ","
+      << "\"dps\":" << (scenario.add_dps ? 1 : 0) << ","
+      << "\"dpv\":" << (scenario.add_dpv ? 1 : 0) << ","
+      << "\"slab_thickness_mm\":" << kSlabThicknessMm << ","
+      << "\"timings_ms\":[";
+  for (size_t i = 0; i < timings_ms.size(); ++i) {
+    if (i > 0) {
+      out << ",";
+    }
+    out << timings_ms[i];
+  }
+  out << "]}\n";
+}
+} // namespace
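+
+// Benchmark arguments are packed positionally as
+//   {streamlines, group_case, dps, dpv[, compression]}
+// by the Apply*Args helpers at the bottom of this file; the counters written
+// by each benchmark echo the same fields so plots can be keyed without
+// re-parsing benchmark names.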
+
+static void BM_TrxFileSize_Float16(benchmark::State &state) {
+  const size_t streamlines = static_cast<size_t>(state.range(0));
+  const auto scenario = static_cast<GroupScenario>(state.range(1));
+  const bool add_dps = state.range(2) != 0;
+  const bool add_dpv = state.range(3) != 0;
+  const bool use_zip = state.range(4) != 0;
+  const auto compression = use_zip ? ZIP_CM_DEFLATE : ZIP_CM_STORE;
+  const size_t skip_zip_at = parse_env_size("TRX_BENCH_SKIP_ZIP_AT", 5000000);
+  const size_t skip_dpv_at = parse_env_size("TRX_BENCH_SKIP_DPV_AT", 10000000);
+  const size_t skip_connectome_at = parse_env_size("TRX_BENCH_SKIP_CONNECTOME_AT", 5000000);
+  if (use_zip && streamlines >= skip_zip_at) {
+    state.SkipWithMessage("zip compression skipped for large streamlines");
+    return;
+  }
+  if (add_dpv && streamlines >= skip_dpv_at) {
+    state.SkipWithMessage("dpv skipped: requires 40-50 GB memory (set TRX_BENCH_SKIP_DPV_AT=0 to force)");
+    return;
+  }
+  if (scenario == GroupScenario::Connectome && streamlines >= skip_connectome_at) {
+    state.SkipWithMessage("connectome groups skipped for large streamlines (set TRX_BENCH_SKIP_CONNECTOME_AT=0 to force)");
+    return;
+  }
+  log_bench_start("BM_TrxFileSize_Float16",
+                  "streamlines=" + std::to_string(streamlines) +
+                      " group_case=" + std::to_string(state.range(1)) +
+                      " dps=" + std::to_string(static_cast<int>(add_dps)) +
+                      " dpv=" + std::to_string(static_cast<int>(add_dpv)) +
+                      " compression=" + std::to_string(static_cast<int>(use_zip)));
+
+  double total_write_ms = 0.0;
+  double total_file_bytes = 0.0;
+  double total_merge_ms = 0.0;
+  double total_build_ms = 0.0;
+  double total_merge_processes = 0.0;
+  for (auto _ : state) {
+    const auto start = std::chrono::steady_clock::now();
+    const auto on_disk =
+        build_trx_file_on_disk(streamlines, scenario, add_dps, add_dpv, compression);
+    const auto end = std::chrono::steady_clock::now();
+    const std::chrono::duration<double, std::milli> elapsed = end - start;
+    total_build_ms += elapsed.count();
+    total_merge_ms += on_disk.shard_merge_ms;
+    total_merge_processes += static_cast<double>(on_disk.shard_processes);
+    total_write_ms += elapsed.count();
+    total_file_bytes += static_cast<double>(file_size_bytes(on_disk.path));
+  }
+
+  state.counters["streamlines"] = static_cast<double>(streamlines);
+  state.counters["group_case"] = static_cast<double>(state.range(1));
+  state.counters["group_count"] = static_cast<double>(group_count_for(scenario));
+  state.counters["dps"] = add_dps ? 1.0 : 0.0;
+  state.counters["dpv"] = add_dpv ? 1.0 : 0.0;
+  state.counters["compression"] = use_zip ? 1.0 : 0.0;
+  state.counters["positions_dtype"] = 16.0;
+  state.counters["write_ms"] = total_write_ms / static_cast<double>(state.iterations());
+  state.counters["build_ms"] = total_build_ms / static_cast<double>(state.iterations());
+  if (total_merge_ms > 0.0) {
+    state.counters["shard_merge_ms"] = total_merge_ms / static_cast<double>(state.iterations());
+    state.counters["shard_processes"] = total_merge_processes / static_cast<double>(state.iterations());
+  }
+  state.counters["file_bytes"] = total_file_bytes / static_cast<double>(state.iterations());
+  state.counters["max_rss_kb"] = get_max_rss_kb();
+
+  log_bench_end("BM_TrxFileSize_Float16",
+                "streamlines=" + std::to_string(streamlines));
+}
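+
+// The translate benchmark below measures a full streaming pass: load the
+// input, add 1.0 to every coordinate chunk-by-chunk (default chunk size 1 GB,
+// see TRX_BENCH_CHUNK_BYTES), write the translated positions, then save an
+// uncompressed archive. Manual timing excludes building the cached input.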
+
+static void BM_TrxStream_TranslateWrite(benchmark::State &state) {
+  const size_t streamlines = static_cast<size_t>(state.range(0));
+  const auto scenario = static_cast<GroupScenario>(state.range(1));
+  const bool add_dps = state.range(2) != 0;
+  const bool add_dpv = state.range(3) != 0;
+  const size_t progress_every = parse_env_size("TRX_BENCH_LOG_PROGRESS_EVERY", 0);
+  log_bench_start("BM_TrxStream_TranslateWrite",
+                  "streamlines=" + std::to_string(streamlines) + " group_case=" + std::to_string(state.range(1)) +
+                      " dps=" + std::to_string(static_cast<int>(add_dps)) +
+                      " dpv=" + std::to_string(static_cast<int>(add_dpv)));
+
+  using Key = KeyHash::Key;
+  static std::unordered_map<Key, TrxOnDisk, KeyHash> cache;
+
+  const Key key{streamlines, static_cast<int>(scenario), add_dps ? 1 : 0, add_dpv ? 1 : 0};
+  if (cache.find(key) == cache.end()) {
+    state.PauseTiming();
+    cache.emplace(key,
+                  build_trx_file_on_disk(streamlines, scenario, add_dps, add_dpv, ZIP_CM_STORE));
+    state.ResumeTiming();
+  }
+
+  const auto &dataset = cache.at(key);
+  if (dataset.shard_processes > 1 && dataset.shard_merge_ms > 0.0) {
+    state.counters["shard_merge_ms"] = dataset.shard_merge_ms;
+    state.counters["shard_processes"] = static_cast<double>(dataset.shard_processes);
+  }
+  for (auto _ : state) {
+    const auto start = std::chrono::steady_clock::now();
+    auto trx = trx::load_any(dataset.path);
+    const size_t chunk_bytes = parse_env_size("TRX_BENCH_CHUNK_BYTES", 1024ULL * 1024ULL * 1024ULL);
+    const std::string out_dir = make_work_dir_name("trx_translate_chunk");
+    trx::PrepareOutputOptions prep_opts;
+    prep_opts.overwrite_existing = true;
+    const auto out_info = trx::prepare_positions_output(trx, out_dir, prep_opts);
+
+    std::ofstream out_positions(out_info.positions_path, std::ios::binary | std::ios::out | std::ios::trunc);
+    if (!out_positions.is_open()) {
+      throw std::runtime_error("Failed to open output positions file: " + out_info.positions_path);
+    }
+
+    trx.for_each_positions_chunk(chunk_bytes,
+                                 [&](trx::TrxScalarType dtype, const void *data, size_t offset, size_t count) {
+      (void)offset;
+      if (progress_every > 0 && ((offset + count) % progress_every == 0)) {
+        std::cerr << "[trx-bench] progress translate points=" << (offset + count)
+                  << " / " << out_info.points << std::endl;
+      }
+      const size_t total_vals = count * 3;
+      if (dtype == trx::TrxScalarType::Float16) {
+        const auto *src = reinterpret_cast<const half *>(data);
+        std::vector<half> tmp(total_vals);
+        for (size_t i = 0; i < total_vals; ++i) {
+          tmp[i] = static_cast<half>(static_cast<float>(src[i]) + 1.0f);
+        }
+        out_positions.write(reinterpret_cast<const char *>(tmp.data()),
+                            static_cast<std::streamsize>(tmp.size() * sizeof(Eigen::half)));
+      } else if (dtype == trx::TrxScalarType::Float32) {
+        const auto *src = reinterpret_cast<const float *>(data);
+        std::vector<float> tmp(total_vals);
+        for (size_t i = 0; i < total_vals; ++i) {
+          tmp[i] = src[i] + 1.0f;
+        }
+        out_positions.write(reinterpret_cast<const char *>(tmp.data()),
+                            static_cast<std::streamsize>(tmp.size() * sizeof(float)));
+      } else {
+        const auto *src = reinterpret_cast<const double *>(data);
+        std::vector<double> tmp(total_vals);
+        for (size_t i = 0; i < total_vals; ++i) {
+          tmp[i] = src[i] + 1.0;
+        }
+        out_positions.write(reinterpret_cast<const char *>(tmp.data()),
+                            static_cast<std::streamsize>(tmp.size() * sizeof(double)));
+      }
+    });
+    out_positions.flush();
+    out_positions.close();
+
+    const std::string out_path = make_temp_path("trx_translate");
+    trx::TrxSaveOptions save_opts;
+    save_opts.mode = trx::TrxSaveMode::Archive;
+    save_opts.compression_standard = ZIP_CM_STORE;
+    save_opts.overwrite_existing = true;
+    trx.save(out_path, save_opts);
+    trx::rm_dir(out_dir);
+    const auto end = std::chrono::steady_clock::now();
+    const std::chrono::duration<double> elapsed = end - start;
+    state.SetIterationTime(elapsed.count());
+
+    std::error_code ec;
+    std::filesystem::remove(out_path, ec);
+    benchmark::DoNotOptimize(trx);
+  }
+
+  state.counters["streamlines"] = static_cast<double>(streamlines);
+  state.counters["group_case"] = static_cast<double>(state.range(1));
+  state.counters["group_count"] = static_cast<double>(group_count_for(scenario));
+  state.counters["dps"] = add_dps ? 1.0 : 0.0;
+  state.counters["dpv"] = add_dpv ? 1.0 : 0.0;
+  state.counters["positions_dtype"] = 16.0;
+  state.counters["max_rss_kb"] = get_max_rss_kb();
+
+  log_bench_end("BM_TrxStream_TranslateWrite",
+                "streamlines=" + std::to_string(streamlines) + " group_case=" + std::to_string(state.range(1)));
+}
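+
+// The query benchmark below issues kSlabCount (20) axial AABB queries, each a
+// full-FOV slab kSlabThicknessMm (5 mm) thick with centres evenly spaced over
+// the z-extent of kFov; per-slab latencies feed the p50/p95 counters and the
+// optional TRX_QUERY_TIMINGS_PATH JSONL dump.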
+
+static void BM_TrxQueryAabb_Slabs(benchmark::State &state) {
+  const size_t streamlines = static_cast<size_t>(state.range(0));
+  const auto scenario = static_cast<GroupScenario>(state.range(1));
+  const bool add_dps = state.range(2) != 0;
+  const bool add_dpv = state.range(3) != 0;
+  log_bench_start("BM_TrxQueryAabb_Slabs",
+                  "streamlines=" + std::to_string(streamlines) + " group_case=" + std::to_string(state.range(1)) +
+                      " dps=" + std::to_string(static_cast<int>(add_dps)) +
+                      " dpv=" + std::to_string(static_cast<int>(add_dpv)));
+
+  using Key = KeyHash::Key;
+  static std::unordered_map<Key, QueryDataset, KeyHash> cache;
+
+  const Key key{streamlines, static_cast<int>(scenario), add_dps ? 1 : 0, add_dpv ? 1 : 0};
+  if (cache.find(key) == cache.end()) {
+    state.PauseTiming();
+    QueryDataset dataset;
+    auto on_disk = build_trx_file_on_disk(streamlines, scenario, add_dps, add_dpv, ZIP_CM_STORE);
+    dataset.trx = trx::load(on_disk.path);
+    dataset.trx->get_or_build_streamline_aabbs();
+    build_slabs(dataset.slab_mins, dataset.slab_maxs);
+    cache.emplace(key, std::move(dataset));
+    state.ResumeTiming();
+  }
+
+  auto &dataset = cache.at(key);
+  for (auto _ : state) {
+    std::vector<double> slab_times_ms;
+    slab_times_ms.reserve(kSlabCount);
+
+    const auto start = std::chrono::steady_clock::now();
+    size_t total = 0;
+    for (size_t i = 0; i < kSlabCount; ++i) {
+      const auto &min_corner = dataset.slab_mins[i];
+      const auto &max_corner = dataset.slab_maxs[i];
+      const auto q_start = std::chrono::steady_clock::now();
+      auto subset = dataset.trx->query_aabb(min_corner, max_corner);
+      const auto q_end = std::chrono::steady_clock::now();
+      const std::chrono::duration<double, std::milli> q_elapsed = q_end - q_start;
+      slab_times_ms.push_back(q_elapsed.count());
+      total += subset->num_streamlines();
+      subset->close();
+    }
+    const auto end = std::chrono::steady_clock::now();
+    const std::chrono::duration<double> elapsed = end - start;
+    state.SetIterationTime(elapsed.count());
+    benchmark::DoNotOptimize(total);
+
+    auto sorted = slab_times_ms;
+    std::sort(sorted.begin(), sorted.end());
+    const auto p50 = sorted[sorted.size() / 2];
+    const auto p95_idx = static_cast<size_t>(std::ceil(0.95 * sorted.size())) - 1;
+    const auto p95 = sorted[std::min(p95_idx, sorted.size() - 1)];
+    state.counters["query_p50_ms"] = p50;
+    state.counters["query_p95_ms"] = p95;
+
+    ScenarioParams params;
+    params.streamlines = streamlines;
+    params.scenario = scenario;
+    params.add_dps = add_dps;
+    params.add_dpv = add_dpv;
+    maybe_write_query_timings(params, slab_times_ms);
+  }
+
+  state.counters["streamlines"] = static_cast<double>(streamlines);
+  state.counters["group_case"] = static_cast<double>(state.range(1));
+  state.counters["group_count"] = static_cast<double>(group_count_for(scenario));
+  state.counters["dps"] = add_dps ? 1.0 : 0.0;
+  state.counters["dpv"] = add_dpv ? 1.0 : 0.0;
+  state.counters["query_count"] = static_cast<double>(kSlabCount);
+  state.counters["slab_thickness_mm"] = kSlabThicknessMm;
+  state.counters["positions_dtype"] = 16.0;
+  state.counters["max_rss_kb"] = get_max_rss_kb();
+
+  log_bench_end("BM_TrxQueryAabb_Slabs",
+                "streamlines=" + std::to_string(streamlines) + " group_case=" + std::to_string(state.range(1)));
+}
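+
+// The Apply*Args helpers below enumerate the cartesian product of streamline
+// counts, group scenarios and dps/dpv flags (plus compression for the size
+// benchmark); under TRX_BENCH_PROFILE=core the dpv and zip axes are pruned
+// above the configured streamline limits.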
+
+static void ApplySizeArgs(benchmark::internal::Benchmark *bench) {
+  const std::array<int64_t, 2> flags = {0, 1};
+  const bool core_profile = is_core_profile();
+  const size_t dpv_limit = core_dpv_max_streamlines();
+  const size_t zip_limit = core_zip_max_streamlines();
+  const auto counts_desc = streamlines_for_benchmarks();
+  const auto groups = group_cases_for_benchmarks();
+  for (const auto count : counts_desc) {
+    const std::vector<int64_t> dpv_flags = (!core_profile || count <= dpv_limit)
+                                               ? std::vector<int64_t>{0, 1}
+                                               : std::vector<int64_t>{0};
+    const std::vector<int64_t> compression_flags = (!core_profile || count <= zip_limit)
+                                                       ? std::vector<int64_t>{0, 1}
+                                                       : std::vector<int64_t>{0};
+    for (const auto group_case : groups) {
+      for (const auto dps : flags) {
+        for (const auto dpv : dpv_flags) {
+          for (const auto compression : compression_flags) {
+            bench->Args({static_cast<int64_t>(count), group_case, dps, dpv, compression});
+          }
+        }
+      }
+    }
+  }
+}
+
+static void ApplyStreamArgs(benchmark::internal::Benchmark *bench) {
+  const std::array<int64_t, 2> flags = {0, 1};
+  const bool core_profile = is_core_profile();
+  const size_t dpv_limit = core_dpv_max_streamlines();
+  const auto groups = group_cases_for_benchmarks();
+  const auto counts_desc = streamlines_for_benchmarks();
+  for (const auto count : counts_desc) {
+    const std::vector<int64_t> dpv_flags = (!core_profile || count <= dpv_limit)
+                                               ? std::vector<int64_t>{0, 1}
+                                               : std::vector<int64_t>{0};
+    for (const auto group_case : groups) {
+      for (const auto dps : flags) {
+        for (const auto dpv : dpv_flags) {
+          bench->Args({static_cast<int64_t>(count), group_case, dps, dpv});
+        }
+      }
+    }
+  }
+}
+
+static void ApplyQueryArgs(benchmark::internal::Benchmark *bench) {
+  const std::array<int64_t, 2> flags = {0, 1};
+  const bool core_profile = is_core_profile();
+  const size_t dpv_limit = core_dpv_max_streamlines();
+  const auto groups = group_cases_for_benchmarks();
+  const auto counts_desc = streamlines_for_benchmarks();
+  for (const auto count : counts_desc) {
+    const std::vector<int64_t> dpv_flags = (!core_profile || count <= dpv_limit)
+                                               ? std::vector<int64_t>{0, 1}
+                                               : std::vector<int64_t>{0};
+    for (const auto group_case : groups) {
+      for (const auto dps : flags) {
+        for (const auto dpv : dpv_flags) {
+          bench->Args({static_cast<int64_t>(count), group_case, dps, dpv});
+        }
+      }
+    }
+  }
+  bench->Iterations(1);
+}
+
+BENCHMARK(BM_TrxFileSize_Float16)
+    ->Apply(ApplySizeArgs)
+    ->Unit(benchmark::kMillisecond);
+
+BENCHMARK(BM_TrxStream_TranslateWrite)
+    ->Apply(ApplyStreamArgs)
+    ->UseManualTime()
+    ->Unit(benchmark::kMillisecond);
+
+BENCHMARK(BM_TrxQueryAabb_Slabs)
+    ->Apply(ApplyQueryArgs)
+    ->UseManualTime()
+    ->Unit(benchmark::kMillisecond);
+
+int main(int argc, char **argv) {
+  // Parse custom flags before benchmark::Initialize
+  bool verbose = false;
+  bool show_help = false;
+  std::string reference_trx;
+
+  // First pass: detect custom flags
+  for (int i = 1; i < argc; ++i) {
+    const std::string arg = argv[i];
+    if (arg == "--verbose" || arg == "-v") {
+      verbose = true;
+    } else if (arg == "--help-custom") {
+      show_help = true;
+    } else if (arg == "--reference-trx" && i + 1 < argc) {
+      reference_trx = argv[i + 1];
+      ++i; // Skip next arg since it's the value
+    }
+  }
+
+  if (show_help) {
+    std::cout << "\nCustom benchmark options:\n"
+              << "  --reference-trx PATH  Path to reference TRX file for sampling (REQUIRED)\n"
+              << "  --verbose, -v         Enable verbose progress logging (prints every 50k streamlines)\n"
+              << "                        Equivalent to: TRX_BENCH_LOG=1 TRX_BENCH_CHILD_LOG=1 \n"
+              << "                        TRX_BENCH_LOG_PROGRESS_EVERY=50000\n"
+              << "  --help-custom         Show this help message\n"
+              << "\nFor standard benchmark options, use --help\n"
+              << std::endl;
+    return 0;
+  }
+
+  // Validate reference TRX path
+  if (reference_trx.empty()) {
+    std::cerr << "Error: --reference-trx flag is required\n"
+              << "Usage: " << argv[0] << " --reference-trx <path> [benchmark_options]\n"
+              << "Use --help-custom for more information\n" << std::endl;
+    return 1;
+  }
+
+  // Check if reference file exists
+  std::error_code ec;
+  if (!std::filesystem::exists(reference_trx, ec)) {
+    std::cerr << "Error: Reference TRX file not found: " << reference_trx << std::endl;
+    return 1;
+  }
+
+  // Set global reference path
+  g_reference_trx_path = reference_trx;
+  std::cerr << "[trx-bench] Using reference TRX: " << g_reference_trx_path << std::endl;
+
+  // Enable verbose logging if requested
+  if (verbose) {
+    setenv("TRX_BENCH_LOG", "1", 0); // Don't override if already set
+    setenv("TRX_BENCH_CHILD_LOG", "1", 0);
+    if (std::getenv("TRX_BENCH_LOG_PROGRESS_EVERY") == nullptr) {
+      setenv("TRX_BENCH_LOG_PROGRESS_EVERY", "50000", 1);
+    }
+    std::cerr << "[trx-bench] Verbose mode enabled (progress every "
+              << parse_env_size("TRX_BENCH_LOG_PROGRESS_EVERY", 50000)
+              << " streamlines)\n" << std::endl;
+  }
+
+  // Second pass: remove custom flags from argv before passing to benchmark::Initialize
+  std::vector<char *> filtered_argv;
+  filtered_argv.push_back(argv[0]); // Keep program name
+  for (int i = 1; i < argc; ++i) {
+    const std::string arg = argv[i];
+    if (arg == "--verbose" || arg == "-v" || arg == "--help-custom") {
+      continue;
+    } else if (arg == "--reference-trx") {
+      ++i; // Skip the next arg (the path value)
+      continue;
+    }
+    filtered_argv.push_back(argv[i]);
+  }
+  int filtered_argc = static_cast<int>(filtered_argv.size());
+
+  ::benchmark::Initialize(&filtered_argc, filtered_argv.data());
+  if (::benchmark::ReportUnrecognizedArguments(filtered_argc, filtered_argv.data())) {
+    return 1;
+  }
+  try {
+    ::benchmark::RunSpecifiedBenchmarks();
+    g_run_success = true;
+  } catch (const std::exception &ex) {
+    std::cerr << "Benchmark failed: " << ex.what() << std::endl;
+    return 1;
+  } catch (...) {
+    std::cerr << "Benchmark failed with unknown exception." << std::endl;
+    return 1;
+  }
+  return 0;
+}
diff --git a/bench/plot_bench.R b/bench/plot_bench.R
new file mode 100755
index 0000000..09ade27
--- /dev/null
+++ b/bench/plot_bench.R
@@ -0,0 +1,435 @@
+#!/usr/bin/env Rscript
+#
+# plot_bench.R - Plot trx-cpp benchmark results with ggplot2
+#
+# Usage:
+#   Rscript bench/plot_bench.R [--bench-dir DIR] [--out-dir DIR] [--help]
+#
+# This script automatically detects benchmark result files in the bench/
+# directory and generates plots for:
+#   - File sizes (BM_TrxFileSize_Float16)
+#   - Translate/write throughput (BM_TrxStream_TranslateWrite)
+#   - Query performance (BM_TrxQueryAabb_Slabs)
+#
+# Expected input files (searched in bench-dir):
+#   - results*.json: Main benchmark results (Google Benchmark JSON format)
+#   - query_timings.jsonl: Canonical per-query timing distributions (JSONL format)
+#   - query_timings_*.jsonl: Legacy/suite-specific timing files (also supported)
+#   - rss_samples.jsonl: Memory samples over time (JSONL format, optional)
+#
+
+suppressPackageStartupMessages({
+  library(jsonlite)
+  library(ggplot2)
+  library(dplyr)
+  library(tidyr)
+  library(scales)
+})
+
+# Constants
+GROUP_LABELS <- c(
+  "0" = "no groups",
+  "1" = "bundle groups (80)",
+  "2" = "connectome groups (1480)"
+)
+
+COMPRESSION_LABELS <- c(
+  "0" = "store (no zip)",
+  "1" = "zip deflate"
+)
+
+#' Parse command line arguments
+parse_args <- function() {
+  args <- commandArgs(trailingOnly = TRUE)
+
+  bench_dir <- "bench"
+  out_dir <- "docs/_static/benchmarks"
+
+  i <- 1
+  while (i <= length(args)) {
+    if (args[i] == "--bench-dir") {
+      bench_dir <- args[i + 1]
+      i <- i + 2
+    } else if (args[i] == "--out-dir") {
+      out_dir <- args[i + 1]
+      i <- i + 2
+    } else if (args[i] == "--help" || args[i] == "-h") {
+      cat("Usage: Rscript plot_bench.R [--bench-dir DIR] [--out-dir DIR]\n")
+      cat("\n")
+      cat("Options:\n")
+      cat("  --bench-dir DIR  Directory containing benchmark JSON files (default: bench)\n")
+      cat("  --out-dir DIR    Output directory for plots (default: docs/_static/benchmarks)\n")
+      cat("  --help, -h       Show this help message\n")
+      quit(status = 0)
+    } else {
+      i <- i + 1
+    }
+  }
+
+  list(bench_dir = bench_dir, out_dir = out_dir)
+}
+
+#' Convert benchmark time to milliseconds
+time_to_ms <- function(bench) {
+  value <- bench$real_time
+  unit <- bench$time_unit
+
+  multiplier <- switch(unit,
+    "ns" = 1e-6,
+    "us" = 1e-3,
+    "ms" = 1,
+    "s" = 1e3,
+    1e-6  # default to nanoseconds
+  )
+
+  value * multiplier
+}
+
+#' Extract base benchmark name
+parse_base_name <- function(name) {
+  sub("/.*", "", name)
+}
+
+#' Load all benchmark result JSON files from a directory
+load_benchmarks <- function(bench_dir) {
+  json_files <- list.files(bench_dir, pattern = "^results.*\\.json$", full.names = TRUE)
+
+  if (length(json_files) == 0) {
+    stop("No results*.json files found in ", bench_dir)
+  }
+
+  cat("Found", length(json_files), "benchmark result file(s):\n")
+  for (f in json_files) {
+    cat("  -", basename(f), "\n")
+  }
+
+  all_rows <- list()
+
+  for (json_file in json_files) {
+    data <- tryCatch({
+      fromJSON(json_file, simplifyDataFrame = FALSE)
+    }, error = function(e) {
+      warning("Failed to parse ", json_file, ": ", e$message)
+      return(NULL)
+    })
+
+    if (is.null(data)) {
+      next
+    }
+
+    benchmarks <- data$benchmarks
+
+    if (is.null(benchmarks) || length(benchmarks) == 0) {
+      warning("No benchmarks found in ", json_file)
+      next
+    }
+
+    for (bench in benchmarks) {
+      name <- bench$name %||% ""
+      if (!grepl("^BM_", name)) next
+
+      row <- list(
+        name = name,
+        base = parse_base_name(name),
+        real_time_ms = time_to_ms(bench),
+        streamlines = bench$streamlines %||% NA,
+        length_profile = bench$length_profile %||% NA,
+        compression = bench$compression %||% NA,
+        group_case = bench$group_case %||% NA,
+        group_count = bench$group_count %||% NA,
+        dps = bench$dps %||% NA,
+        dpv = bench$dpv %||% NA,
+        write_ms = bench$write_ms %||% NA,
+        build_ms = bench$build_ms %||% NA,
+        file_bytes = bench$file_bytes %||% NA,
+        max_rss_kb = bench$max_rss_kb %||% NA,
+        query_p50_ms = bench$query_p50_ms %||% NA,
+        query_p95_ms = bench$query_p95_ms %||% NA,
+        shard_merge_ms = bench$shard_merge_ms %||% NA,
+        shard_processes = bench$shard_processes %||% NA,
+        source_file = basename(json_file)
+      )
+
+      all_rows[[length(all_rows) + 1]] <- row
+    }
+  }
+
+  if (length(all_rows) == 0) {
+    stop("No valid benchmarks found in any JSON file")
+  }
+
+  df <- bind_rows(all_rows)
+
+  cat("\nLoaded", nrow(df), "benchmark results\n")
+  cat("Benchmark types found:\n")
+  for (base in unique(df$base)) {
+    count <- sum(df$base == base)
+    cat("  -", base, ":", count, "results\n")
+  }
+
+  df
+}
+
+#' Plot file sizes
+plot_file_sizes <- function(df, out_dir) {
+  sub_df <- df %>%
+    filter(base == "BM_TrxFileSize_Float16") %>%
+    filter(!is.na(file_bytes), !is.na(streamlines))
+
+  if (nrow(sub_df) == 0) {
+    cat("No BM_TrxFileSize_Float16 results found, skipping file size plot\n")
+    return(invisible(NULL))
+  }
+
+  sub_df <- sub_df %>%
+    mutate(
+      file_mb = file_bytes / 1e6,
+      compression_label = recode(as.character(compression), !!!COMPRESSION_LABELS),
+      group_label = recode(
+        as.character(ifelse(is.na(group_case), "0", as.character(group_case))),
+        !!!GROUP_LABELS),
+      dp_label = sprintf("dpv=%d, dps=%d", as.integer(dpv), as.integer(dps))
+    )
+
+  n_group_levels <- length(unique(sub_df$group_label))
+  plot_height <- if (n_group_levels > 1) 5 + 3 * n_group_levels else 7
+
+  p <- ggplot(sub_df, aes(x = streamlines, y = file_mb,
+                          color = dp_label)) +
+    geom_line(linewidth = 0.8) +
+    geom_point(size = 2) +
+    facet_grid(group_label ~ compression_label) +
+    scale_x_continuous(labels = label_number(scale = 1e-6, suffix = "M")) +
+    scale_y_continuous(labels = label_number()) +
+    labs(
+      title = "TRX file size vs streamlines (float16 positions)",
+      x = "Streamlines",
+      y = "File size (MB)",
+      color = "Data per streamline/vertex"
+    ) +
+    theme_bw() +
+    theme(
+      legend.position = "bottom",
+      legend.box = "vertical",
+      strip.background = element_rect(fill = "grey90")
+    )
+
+  out_path <- file.path(out_dir, "trx_size_vs_streamlines.png")
+  ggsave(out_path, p, width = 12, height = plot_height, dpi = 160)
+  cat("Saved:", out_path, "\n")
+}
+
+#' Plot translate/write performance
+plot_translate_write <- function(df, out_dir) {
+  sub_df <- df %>%
+    filter(base == "BM_TrxStream_TranslateWrite") %>%
+    filter(!is.na(real_time_ms), !is.na(streamlines))
+
+  if (nrow(sub_df) == 0) {
+    cat("No BM_TrxStream_TranslateWrite results found, skipping translate plots\n")
+    return(invisible(NULL))
+  }
+
+  sub_df <- sub_df %>%
+    mutate(
+      group_label = recode(as.character(group_case), !!!GROUP_LABELS),
+      dp_label = sprintf("dpv=%d, dps=%d", as.integer(dpv), as.integer(dps)),
+      rss_mb = max_rss_kb / 1024
+    )
+
+  # Time plot
+  p_time <- ggplot(sub_df, aes(x = streamlines, y = real_time_ms,
+                               color = dp_label)) +
+    geom_line(linewidth = 0.8) +
+    geom_point(size = 2) +
+    facet_wrap(~group_label, ncol = 2) +
+    scale_x_continuous(labels = label_number(scale = 1e-6, suffix = "M")) +
+    scale_y_continuous(labels = label_number()) +
+    labs(
+      title = "Translate + stream write throughput",
+      x = "Streamlines",
+      y = "Time (ms)",
+      color = "Data per streamline/vertex"
+    ) +
+    theme_bw() +
+    theme(
+      legend.position = "bottom",
+      strip.background = element_rect(fill = "grey90")
+    )
+
+  out_path <- file.path(out_dir, "trx_translate_write_time.png")
+  ggsave(out_path, p_time, width = 12, height = 5, dpi = 160)
+  cat("Saved:", out_path, "\n")
+
+  # RSS plot
+  p_rss <- ggplot(sub_df, aes(x = streamlines, y = rss_mb,
+                              color = dp_label)) +
+    geom_line(linewidth = 0.8) +
+    geom_point(size = 2) +
+    facet_wrap(~group_label, ncol = 2) +
+    scale_x_continuous(labels = label_number(scale = 1e-6, suffix = "M")) +
+    scale_y_continuous(labels = label_number()) +
+    labs(
+      title = "Translate + stream write memory usage",
+      x = "Streamlines",
+      y = "Max RSS (MB)",
+      color = "Data per streamline/vertex"
+    ) +
+    theme_bw() +
+    theme(
+      legend.position = "bottom",
+      strip.background = element_rect(fill = "grey90")
+    )
+
+  out_path <- file.path(out_dir, "trx_translate_write_rss.png")
+  ggsave(out_path, p_rss, width = 12, height = 5, dpi = 160)
+  cat("Saved:", out_path, "\n")
+}
+
+#' Load query timings from JSONL file
+load_query_timings <- function(jsonl_path) {
+  if (!file.exists(jsonl_path)) {
+    return(NULL)
+  }
+
+  lines <- readLines(jsonl_path, warn = FALSE)
+  lines <- lines[nzchar(lines)]
+
+  if (length(lines) == 0) {
+    return(NULL)
+  }
+
+  rows <- lapply(lines, function(line) {
+    tryCatch({
+      obj <- fromJSON(line, simplifyDataFrame = FALSE)
+      list(
+        streamlines = obj$streamlines %||% NA,
+        group_case = obj$group_case %||% NA,
+        group_count = obj$group_count %||% NA,
+        dps = obj$dps %||% NA,
+        dpv = obj$dpv %||% NA,
+        slab_thickness_mm = obj$slab_thickness_mm %||% NA,
+        timings_ms = I(list(unlist(obj$timings_ms)))
+      )
+    }, error = function(e) NULL)
+  })
+
+  rows <- rows[!sapply(rows, is.null)]
+
+  if (length(rows) == 0) {
+    return(NULL)
+  }
+
+  bind_rows(rows)
+}
+
+#' Load query timings from canonical and legacy JSONL files
+load_all_query_timings <- function(bench_dir) {
+  canonical <- file.path(bench_dir, "query_timings.jsonl")
+  legacy <- list.files(bench_dir, pattern = "^query_timings.*\\.jsonl$", full.names = TRUE)
+  jsonl_paths <- unique(c(canonical, legacy))
+  jsonl_paths <- jsonl_paths[file.exists(jsonl_paths)]
+
+  if (length(jsonl_paths) == 0) {
+    return(NULL)
+  }
+
+  dfs <- lapply(jsonl_paths, function(path) {
+    df <- load_query_timings(path)
+    if (is.null(df) || nrow(df) == 0) {
+      return(NULL)
+    }
+    df$source_file <- basename(path)
+    df
+  })
+  dfs <- dfs[!sapply(dfs, is.null)]
+  if (length(dfs) == 0) {
+    return(NULL)
+  }
+  bind_rows(dfs)
+}
+
+#' Plot query timing distributions
+plot_query_timings <- function(bench_dir, out_dir, group_case = 0, dpv = 0, dps = 0) {
+  df <- load_all_query_timings(bench_dir)
+
+  if (is.null(df) || nrow(df) == 0) {
+    cat("No query_timings*.jsonl found or empty, skipping query timing plot\n")
+    return(invisible(NULL))
+  }
+
+  # Filter by specified conditions
+  df_filtered <- df %>%
+    filter(
+      group_case == !!group_case,
+      dpv == !!dpv,
+      dps == !!dps
+    )
+
+  if (nrow(df_filtered) == 0) {
+    cat("No query timings matching filters (group_case=", group_case,
+        ", dpv=", dpv, ", dps=", dps, "), skipping plot\n", sep = "")
+    return(invisible(NULL))
+  }
+
+  # Expand timings into long format
+  timing_data <- df_filtered %>%
+    mutate(streamlines_label = format(streamlines, big.mark = ",")) %>%
+    select(streamlines, streamlines_label, timings_ms) %>%
+    unnest(timings_ms) %>%
+    group_by(streamlines, streamlines_label) %>%
+    mutate(query_id = row_number()) %>%
+    ungroup()
+
+  # Create boxplot
+  group_label <- GROUP_LABELS[as.character(group_case)]
+
+  p <- ggplot(timing_data, aes(x = streamlines_label, y = timings_ms)) +
+    geom_boxplot(fill = "steelblue", alpha = 0.7, outlier.size = 0.5) +
+    labs(
+      title = sprintf("Slab query timings (%s, dpv=%d, dps=%d)",
+                      group_label, dpv, dps),
+      x = "Streamlines",
+      y = "Per-slab query time (ms)"
+    ) +
+    theme_bw() +
+    theme(
+      axis.text.x = element_text(angle = 45, hjust = 1)
+    )
+
+  out_path <- file.path(out_dir, "trx_query_slab_timings.png")
+  ggsave(out_path, p, width = 10, height = 6, dpi = 160)
+  cat("Saved:", out_path, "\n")
+}
+
+#' Main function
+main <- function() {
+  args <- parse_args()
+
+  # Create output directory
+  dir.create(args$out_dir, recursive = TRUE, showWarnings = FALSE)
+
+  cat("\n=== TRX-CPP Benchmark Plotting ===\n\n")
+  cat("Benchmark directory:", args$bench_dir, "\n")
+  cat("Output directory:", args$out_dir, "\n\n")
+
+  # Load benchmark results
+  df <- load_benchmarks(args$bench_dir)
+
+  cat("\n--- Generating plots ---\n\n")
+
+  # Generate plots
+  plot_file_sizes(df, args$out_dir)
+  plot_translate_write(df, args$out_dir)
+  plot_query_timings(args$bench_dir, args$out_dir, group_case = 0, dpv = 0, dps = 0)
+
+  cat("\nDone! Plots saved to:", args$out_dir, "\n")
+}
+
+# Define null-coalescing operator
+`%||%` <- function(x, y) if (is.null(x)) y else x
+
+# Run main if executed as script
+if (!interactive()) {
+  main()
+}
diff --git a/bench/run_benchmarks.sh b/bench/run_benchmarks.sh
new file mode 100755
index 0000000..52272d3
--- /dev/null
+++ b/bench/run_benchmarks.sh
@@ -0,0 +1,237 @@
+#!/bin/bash
+#
+# run_benchmarks.sh - Run trx-cpp benchmarks separately to minimize memory usage
+#
+# Usage:
+#   ./bench/run_benchmarks.sh [options]
+#
+# Options:
+#   --realdata         Run real data benchmarks (bench_trx_realdata, default)
+#   --synthetic        Run synthetic benchmarks (bench_trx_stream)
+#   --both             Run both synthetic and real data benchmarks
+#   --reference PATH   Path to reference TRX file (for realdata)
+#   --out-dir DIR      Output directory for JSON results (default: bench)
+#   --build-dir DIR    Build directory with benchmark binaries (default: build-release)
+#   --profile MODE     Benchmark profile: core (default) or full
+#   --allow-synth-mp   Allow synthetic multiprocessing (experimental)
+#   --verbose          Enable verbose progress logging
+#   --help             Show this help message
+#
+# Environment variables (optional):
+#   TRX_BENCH_BUFFER_MULTIPLIER  Buffer size multiplier for slow storage (default: 1)
+#   TRX_BENCH_MAX_STREAMLINES    Maximum streamline count to test (profile default)
+#   TRX_BENCH_PROCESSES          Number of processes (synthetic only, default: 1)
+#
+
+set -e  # Exit on error
+
+# Default values
+BENCH_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+PROJECT_ROOT="$(dirname "$BENCH_DIR")"
+OUT_DIR="$BENCH_DIR"
+RUN_SYNTHETIC=false
+RUN_REALDATA=true
+REFERENCE_TRX="$PROJECT_ROOT/test-data/10milHCP_dps-sift2.trx"
+VERBOSE_FLAG=""
+BUILD_DIR="$PROJECT_ROOT/build-release"
+PROFILE="core"
+ALLOW_SYNTH_MP=false
+
+# Parse arguments
+while [[ $# -gt 0 ]]; do
+  case $1 in
+    --synthetic)
+      RUN_SYNTHETIC=true
+      shift
+      ;;
+    --realdata)
+      RUN_REALDATA=true
+      shift
+      ;;
+    --both)
+      RUN_SYNTHETIC=true
+      RUN_REALDATA=true
+      shift
+      ;;
+    --profile)
+      PROFILE="$2"
+      shift 2
+      ;;
+    --allow-synth-mp)
+      ALLOW_SYNTH_MP=true
+      shift
+      ;;
+    --reference)
+      REFERENCE_TRX="$2"
+      shift 2
+      ;;
+    --out-dir)
+      OUT_DIR="$2"
+      shift 2
+      ;;
+    --verbose)
+      VERBOSE_FLAG="--verbose"
+      shift
+      ;;
+    --build-dir)
+      BUILD_DIR="$2"
+      shift 2
+      ;;
+    --help)
+      head -n 25 "$0" | tail -n +2 | sed 's/^# //' | sed 's/^#//'
+      exit 0
+      ;;
+    *)
+      echo "Unknown option: $1"
+      echo "Use --help for usage information"
+      exit 1
+      ;;
+  esac
+done
+
+if [[ "$PROFILE" != "core" && "$PROFILE" != "full" ]]; then
+  echo "Error: --profile must be 'core' or 'full' (got '$PROFILE')"
+  exit 1
+fi
+
+# Create output directory
+mkdir -p "$OUT_DIR"
+
+echo "========================================"
+echo "TRX-CPP Benchmark Runner"
+echo "========================================"
+echo "Output directory: $OUT_DIR"
+echo "Build directory: $BUILD_DIR"
+echo "Run synthetic: $RUN_SYNTHETIC"
+echo "Run realdata: $RUN_REALDATA"
+echo "Profile: $PROFILE"
+echo "Synthetic multiprocessing: $([[ "$ALLOW_SYNTH_MP" == "true" ]] && echo "experimental" || echo "disabled")"
+if [[ "$RUN_REALDATA" == "true" ]]; then
+  echo "Reference TRX: $REFERENCE_TRX"
+fi
+echo "Verbose: ${VERBOSE_FLAG:-disabled}"
+echo ""
--benchmark_out="$output_file" \ + --benchmark_out_format=json + else + "$executable" $VERBOSE_FLAG $extra_flags \ + --benchmark_filter="$filter" \ + --benchmark_out="$output_file" \ + --benchmark_out_format=json + fi + + echo "✓ Completed: $(basename "$output_file")" + echo "" +} + +# Build profile environment defaults (users can override by exporting env vars) +if [[ "$PROFILE" == "core" ]]; then + CORE_ENV="TRX_BENCH_PROFILE=${TRX_BENCH_PROFILE:-core} TRX_BENCH_MAX_STREAMLINES=${TRX_BENCH_MAX_STREAMLINES:-5000000} TRX_BENCH_SKIP_DPV_AT=${TRX_BENCH_SKIP_DPV_AT:-1000000} TRX_BENCH_QUERY_CACHE_MAX=${TRX_BENCH_QUERY_CACHE_MAX:-5} TRX_BENCH_CORE_INCLUDE_BUNDLES=${TRX_BENCH_CORE_INCLUDE_BUNDLES:-0} TRX_BENCH_INCLUDE_CONNECTOME=${TRX_BENCH_INCLUDE_CONNECTOME:-0} TRX_BENCH_CORE_DPV_MAX_STREAMLINES=${TRX_BENCH_CORE_DPV_MAX_STREAMLINES:-1000000} TRX_BENCH_CORE_ZIP_MAX_STREAMLINES=${TRX_BENCH_CORE_ZIP_MAX_STREAMLINES:-1000000}" +else + CORE_ENV="TRX_BENCH_PROFILE=${TRX_BENCH_PROFILE:-full} TRX_BENCH_MAX_STREAMLINES=${TRX_BENCH_MAX_STREAMLINES:-10000000} TRX_BENCH_SKIP_DPV_AT=${TRX_BENCH_SKIP_DPV_AT:-10000000} TRX_BENCH_QUERY_CACHE_MAX=${TRX_BENCH_QUERY_CACHE_MAX:-10} TRX_BENCH_INCLUDE_CONNECTOME=${TRX_BENCH_INCLUDE_CONNECTOME:-1}" +fi + +SYNTH_ENV="$CORE_ENV" +if [[ "$ALLOW_SYNTH_MP" != "true" ]]; then + SYNTH_ENV="$SYNTH_ENV TRX_BENCH_PROCESSES=${TRX_BENCH_PROCESSES:-1}" +fi + +# Synthetic benchmarks +if [[ "$RUN_SYNTHETIC" == "true" ]]; then + SYNTHETIC_BIN="$BUILD_DIR/bench/bench_trx_stream" + + if [[ ! -f "$SYNTHETIC_BIN" ]]; then + echo "Error: Synthetic benchmark not found: $SYNTHETIC_BIN" + echo "Build with: cmake --build $BUILD_DIR --target bench_trx_stream" + exit 1 + fi + + echo "========================================" + echo "SYNTHETIC DATA BENCHMARKS" + echo "========================================" + echo "" + + run_benchmark "$SYNTHETIC_BIN" "BM_TrxFileSize_Float16" \ + "$OUT_DIR/results_synthetic_filesize.json" \ + "$SYNTH_ENV" + + run_benchmark "$SYNTHETIC_BIN" "BM_TrxStream_TranslateWrite" \ + "$OUT_DIR/results_synthetic_translate.json" \ + "$SYNTH_ENV" + + run_benchmark "$SYNTHETIC_BIN" "BM_TrxQueryAabb_Slabs" \ + "$OUT_DIR/results_synthetic_query.json" \ + "$SYNTH_ENV TRX_QUERY_TIMINGS_PATH=$OUT_DIR/query_timings_synthetic.jsonl" + + echo "✓ All synthetic benchmarks completed" + echo "" +fi + +# Real data benchmarks +if [[ "$RUN_REALDATA" == "true" ]]; then + REALDATA_BIN="$BUILD_DIR/bench/bench_trx_realdata" + + if [[ ! -f "$REALDATA_BIN" ]]; then + echo "Error: Real-data benchmark not found: $REALDATA_BIN" + echo "Build with: cmake --build $BUILD_DIR --target bench_trx_realdata" + exit 1 + fi + + if [[ ! 
-f "$REFERENCE_TRX" ]]; then + echo "Error: Reference TRX file not found: $REFERENCE_TRX" + echo "Use --reference to specify the path" + exit 1 + fi + + echo "========================================" + echo "REAL DATA BENCHMARKS" + echo "========================================" + echo "" + + # Reference flag for all realdata benchmarks + REALDATA_FLAGS="--reference-trx $REFERENCE_TRX" + REALDATA_ENV="$CORE_ENV" + + run_benchmark "$REALDATA_BIN" "BM_TrxFileSize_Float16" \ + "$OUT_DIR/results_realdata_filesize.json" \ + "$REALDATA_ENV" \ + "$REALDATA_FLAGS" + + run_benchmark "$REALDATA_BIN" "BM_TrxStream_TranslateWrite" \ + "$OUT_DIR/results_realdata_translate.json" \ + "$REALDATA_ENV" \ + "$REALDATA_FLAGS" + + run_benchmark "$REALDATA_BIN" "BM_TrxQueryAabb_Slabs" \ + "$OUT_DIR/results_realdata_query.json" \ + "$REALDATA_ENV TRX_QUERY_TIMINGS_PATH=$OUT_DIR/query_timings.jsonl" \ + "$REALDATA_FLAGS" + + echo "✓ All real-data benchmarks completed" + echo "" +fi + +echo "========================================" +echo "BENCHMARK SUMMARY" +echo "========================================" +echo "" +echo "Results saved to: $OUT_DIR" +ls -lh "$OUT_DIR"/results_*.json 2>/dev/null || echo "No result files found" +echo "" +echo "To generate plots:" +echo " Rscript bench/plot_bench.R --bench-dir $OUT_DIR --out-dir docs/_static/benchmarks" +echo "" diff --git a/codecov.yml b/codecov.yml index 62c2cc1..02ccda9 100644 --- a/codecov.yml +++ b/codecov.yml @@ -2,6 +2,7 @@ coverage: ignore: - "third_party" - "examples" + - "bench" status: project: default: diff --git a/docs/_static/benchmarks/.gitkeep b/docs/_static/benchmarks/.gitkeep new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/docs/_static/benchmarks/.gitkeep @@ -0,0 +1 @@ + diff --git a/docs/_static/benchmarks/trx_query_slab_timings.png b/docs/_static/benchmarks/trx_query_slab_timings.png new file mode 100644 index 0000000..b950715 Binary files /dev/null and b/docs/_static/benchmarks/trx_query_slab_timings.png differ diff --git a/docs/_static/benchmarks/trx_size_vs_streamlines.png b/docs/_static/benchmarks/trx_size_vs_streamlines.png new file mode 100644 index 0000000..f43a93f Binary files /dev/null and b/docs/_static/benchmarks/trx_size_vs_streamlines.png differ diff --git a/docs/_static/benchmarks/trx_translate_write_rss.png b/docs/_static/benchmarks/trx_translate_write_rss.png new file mode 100644 index 0000000..1a4b8bc Binary files /dev/null and b/docs/_static/benchmarks/trx_translate_write_rss.png differ diff --git a/docs/_static/benchmarks/trx_translate_write_time.png b/docs/_static/benchmarks/trx_translate_write_time.png new file mode 100644 index 0000000..cfed582 Binary files /dev/null and b/docs/_static/benchmarks/trx_translate_write_time.png differ diff --git a/docs/benchmarks.rst b/docs/benchmarks.rst new file mode 100644 index 0000000..e10ddfa --- /dev/null +++ b/docs/benchmarks.rst @@ -0,0 +1,191 @@ +Benchmarks +========== + +This page documents the benchmarking suite and how to interpret the results. +The benchmarks are designed for realistic tractography workloads (HPC scale), +not for CI. They focus on file size, throughput, and interactive spatial queries. 
+ +Data model +---------- + +All benchmarks synthesize smooth, slightly curved streamlines in a realistic +field of view: + +- **Lengths:** random between 20 and 500 mm (profiles skew short/medium/long) +- **Field of view:** x = [-70, 70], y = [-108, 79], z = [-60, 75] (mm, RAS+) +- **Streamline counts:** 100k, 500k, 1M, 5M, 10M +- **Groups:** none, 80 bundle groups, or 4950 connectome groups (100 regions) +- **DPV/DPS:** either present (1 value) or absent + +Positions are stored as float16 to highlight storage efficiency. + +TRX size vs streamline count +---------------------------- + +This benchmark writes TRX files with float16 positions and measures the final +on-disk size for different streamline counts. It compares short/medium/long +length profiles, DPV/DPS presence, and zip compression (store vs deflate). + +.. figure:: _static/benchmarks/trx_size_vs_streamlines.png + :alt: TRX file size vs streamlines + :align: center + + File size (MB) as a function of streamline count. + +Translate + stream write throughput +----------------------------------- + +This benchmark loads a TRX file, iterates through every streamline, translates +each point by +1 mm in x/y/z, and streams the result into a new TRX file. It +reports total wall time and max RSS so researchers can understand throughput +and memory pressure on both clusters and laptops. + +.. figure:: _static/benchmarks/trx_translate_write_time.png + :alt: Translate + stream write time + :align: center + + End-to-end time for translating and rewriting streamlines. + +.. figure:: _static/benchmarks/trx_translate_write_rss.png + :alt: Translate + stream write RSS + :align: center + + Max RSS during translate + stream write. + +Spatial slab query latency +-------------------------- + +This benchmark precomputes per-streamline AABBs and then issues 100 spatial +queries using 5 mm slabs that sweep through the tractogram volume. Each slab +query mimics a GUI slice update and records its timing so distributions can be +visualized. + +.. figure:: _static/benchmarks/trx_query_slab_timings.png + :alt: Slab query timings + :align: center + + Distribution of per-slab query latency. + +Performance characteristics +--------------------------- + +Benchmark results vary significantly based on storage performance: + +**SSD (solid-state drives):** +- **CPU-bound**: Disk writes complete faster than streamline generation +- High CPU utilization (~100%) +- Results reflect pure computational throughput + +**HDD (spinning disks):** +- **I/O-bound**: Disk writes are the bottleneck +- Low CPU utilization (~5-10%) +- Results reflect realistic workstation performance with storage latency + +Both scenarios are valuable. SSD results show the library's maximum throughput, +while HDD results show real-world performance on cost-effective storage. On +Linux, monitor I/O wait time with ``iostat -x 1`` to identify the bottleneck. + +For spinning disks or network filesystems, you may want to increase buffer sizes +to amortize I/O latency. Set ``TRX_BENCH_BUFFER_MULTIPLIER`` to use larger +buffers (e.g., ``TRX_BENCH_BUFFER_MULTIPLIER=4`` uses 4× the default buffer +sizes). + +Running the benchmarks +---------------------- + +Build and run the benchmarks, then plot results with matplotlib: + +.. code-block:: bash + + cmake -S . -B build -DTRX_BUILD_BENCHMARKS=ON + cmake --build build --target bench_trx_stream + + # Run benchmarks (this can be long for large datasets). 
+ ./build/bench/bench_trx_stream \ + --benchmark_out=bench/results.json \ + --benchmark_out_format=json + + # For slower storage (HDD, NFS), use larger buffers: + TRX_BENCH_BUFFER_MULTIPLIER=4 \ + ./build/bench/bench_trx_stream \ + --benchmark_out=bench/results_hdd.json \ + --benchmark_out_format=json + + # Capture per-slab timings for query distributions. + TRX_QUERY_TIMINGS_PATH=bench/query_timings.jsonl \ + ./build/bench/bench_trx_stream \ + --benchmark_filter=BM_TrxQueryAabb_Slabs \ + --benchmark_out=bench/results.json \ + --benchmark_out_format=json + + # Optional: record RSS samples for file-size runs. + TRX_RSS_SAMPLES_PATH=bench/rss_samples.jsonl \ + TRX_RSS_SAMPLE_EVERY=50000 \ + TRX_RSS_SAMPLE_MS=500 \ + ./build/bench/bench_trx_stream \ + --benchmark_filter=BM_TrxFileSize_Float16 \ + --benchmark_out=bench/results.json \ + --benchmark_out_format=json + + # Generate plots into docs/_static/benchmarks. + python bench/plot_bench.py bench/results.json \ + --query-json bench/query_timings.jsonl \ + --out-dir docs/_static/benchmarks + +The query plot defaults to the "no groups, no DPV/DPS" case. Use +``--group-case``, ``--dpv``, and ``--dps`` in ``plot_bench.py`` to select other +scenarios. + +Environment variables +--------------------- + +The benchmark suite supports several environment variables for customization: + +**Multiprocessing:** + +- ``TRX_BENCH_PROCESSES`` (default: 1): Number of processes for parallel shard + generation. Recommended: number of physical cores. +- ``TRX_BENCH_MP_MIN_STREAMLINES`` (default: 1000000): Minimum streamline count + to enable multiprocessing. Below this threshold, single-process mode is used. +- ``TRX_BENCH_KEEP_SHARDS`` (default: 0): Set to 1 to preserve shard directories + after merging for debugging. +- ``TRX_BENCH_SHARD_WAIT_MS`` (default: 10000): Timeout in milliseconds for + waiting for shard completion markers. + +**Buffering (for slow storage):** + +- ``TRX_BENCH_BUFFER_MULTIPLIER`` (default: 1): Scales position and metadata + buffer sizes. Use larger values (2-8) for spinning disks or network + filesystems to reduce I/O latency. Example: multiplier=4 uses 64 MB → 256 MB + for small datasets, 256 MB → 1 GB for 1M streamlines, 2 GB → 8 GB for 5M+ + streamlines. + +**Performance tuning:** + +- ``TRX_BENCH_THREADS`` (default: hardware_concurrency): Worker threads for + streamline generation within each process. +- ``TRX_BENCH_BATCH`` (default: 1000): Streamlines per batch in the producer- + consumer queue. +- ``TRX_BENCH_QUEUE_MAX`` (default: 8): Maximum batches in flight between + producers and consumer. + +**Dataset control:** + +- ``TRX_BENCH_ONLY_STREAMLINES`` (default: 0): If nonzero, benchmark only this + streamline count instead of the full range. +- ``TRX_BENCH_MAX_STREAMLINES`` (default: 10000000): Maximum streamline count + to benchmark. Use smaller values for faster iteration. +- ``TRX_BENCH_SKIP_ZIP_AT`` (default: 5000000): Skip zip compression for + streamline counts at or above this threshold. + +**Logging and diagnostics:** + +- ``TRX_BENCH_LOG`` (default: 0): Enable benchmark progress logging to stderr. +- ``TRX_BENCH_CHILD_LOG`` (default: 0): Enable logging from child processes in + multiprocess mode. +- ``TRX_BENCH_LOG_PROGRESS_EVERY`` (default: 0): Log progress every N + streamlines. + +When running with multiprocessing, the benchmark uses +``finalize_directory_persistent()`` to write shard outputs without removing +pre-created directories, avoiding race conditions in the parallel workflow. 
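+
+As an aside, a minimal sketch of the lookup pattern these variables imply
+(the helper below is illustrative, not the benchmark's actual code; only the
+variable names above come from the suite):
+
+.. code-block:: cpp
+
+   #include <cstdlib>
+
+   // Read an integer-valued environment variable, falling back to a default
+   // when it is unset or empty (e.g., TRX_BENCH_BATCH defaults to 1000).
+   static long env_or_default(const char *name, long fallback) {
+     const char *raw = std::getenv(name);
+     if (raw == nullptr || *raw == '\0') {
+       return fallback;
+     }
+     return std::strtol(raw, nullptr, 10);
+   }
+
+   const long batch = env_or_default("TRX_BENCH_BATCH", 1000);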
diff --git a/docs/index.rst b/docs/index.rst index 58f0d3f..f10b93a 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -15,6 +15,7 @@ tractography file format. overview building usage + benchmarks downstream_usage linting diff --git a/docs/usage.rst b/docs/usage.rst index 16860ee..04b2619 100644 --- a/docs/usage.rst +++ b/docs/usage.rst @@ -51,6 +51,174 @@ Write a TRX file trx.save("tracks_copy.trx", ZIP_CM_STORE); trx.close(); +Thread-safe streaming pattern +----------------------------- + +``TrxStream`` is **not** thread-safe for concurrent writes. A common pattern for +multi-core streamline generation is to use worker threads for generation and a +single writer thread (or the main thread) to append to ``TrxStream``. + +.. code-block:: cpp + + #include + #include + #include + #include + #include + + struct Batch { + std::vector>> streamlines; + }; + + std::mutex mutex; + std::condition_variable cv; + std::queue queue; + bool done = false; + + // Worker threads: generate streamlines and push batches. + auto producer = [&]() { + Batch batch; + batch.streamlines.reserve(1000); + for (int i = 0; i < 1000; ++i) { + std::vector> points = {/* ... generate ... */}; + batch.streamlines.push_back(std::move(points)); + } + { + std::lock_guard lock(mutex); + queue.push(std::move(batch)); + } + cv.notify_one(); + }; + + // Writer thread (single): pop batches and push into TrxStream. + trx::TrxStream stream("float16"); + auto consumer = [&]() { + for (;;) { + std::unique_lock lock(mutex); + cv.wait(lock, [&]() { return done || !queue.empty(); }); + if (queue.empty() && done) { + return; + } + Batch batch = std::move(queue.front()); + queue.pop(); + lock.unlock(); + + for (const auto &points : batch.streamlines) { + stream.push_streamline(points); + } + } + }; + + std::thread writer(consumer); + std::thread t1(producer); + std::thread t2(producer); + t1.join(); + t2.join(); + { + std::lock_guard lock(mutex); + done = true; + } + cv.notify_all(); + writer.join(); + + stream.finalize("tracks.trx", ZIP_CM_STORE); + +Process-based sharding and merge +-------------------------------- + +For large tractograms it is common to generate streamlines in separate +processes, write shard outputs, and merge them later. ``TrxStream`` provides two +finalization methods for directory output: + +- ``finalize_directory()`` — Single-process variant that removes any existing + directory before writing. Use when you control the entire lifecycle. + +- ``finalize_directory_persistent()`` — Multiprocess-safe variant that does NOT + remove existing directories. Use when coordinating parallel writes where a + parent process may pre-create output directories. + +Recommended multiprocess pattern: + +1. **Parent** pre-creates shard directories to validate filesystem writability. +2. Each **child process** writes a directory shard using + ``finalize_directory_persistent()``. +3. After finalization completes, child writes a sentinel file (e.g., ``SHARD_OK``) + to signal completion. +4. **Parent** waits for all ``SHARD_OK`` markers before merging shards. + +This pattern avoids race conditions where the parent checks for directory +existence while children are still writing. + +.. code-block:: cpp + + // Parent process: pre-create shard directories + for (size_t i = 0; i < num_shards; ++i) { + const std::string shard_path = "shards/shard_" + std::to_string(i); + std::filesystem::create_directories(shard_path); + } + + // Fork child processes... + +.. 
code-block:: cpp + + // Child process: write to pre-created directory + trx::TrxStream stream("float16"); + // ... push streamlines, dpv, dps, groups ... + stream.finalize_directory_persistent("/path/to/shards/shard_0"); + + // Signal completion to parent + std::ofstream ok("/path/to/shards/shard_0/SHARD_OK"); + ok << "ok\n"; + ok.close(); + +.. code-block:: cpp + + // Parent process (after waiting for all SHARD_OK markers) + // Merge by concatenating positions/DPV/DPS, adjusting offsets/groups. + // See bench/bench_trx_stream.cpp for a reference merge implementation. + +.. note:: + Use ``finalize_directory()`` for single-process writes where you want to + ensure a clean output state. Use ``finalize_directory_persistent()`` for + multiprocess workflows to avoid removing directories that may be checked + for existence by other processes. + +MRtrix-style write kernel (single-writer) +----------------------------------------- + +MRtrix uses a multi-threaded producer stage and a single-writer kernel to +serialize streamlines to disk. The same pattern works for TRX by letting the +writer own the ``TrxStream`` and accepting batches from the thread queue. + +.. code-block:: cpp + + #include + #include + #include + + struct TrxWriteKernel { + explicit TrxWriteKernel(const std::string &path) + : stream("float16"), out_path(path) {} + + void operator()(const std::vector>> &batch) { + for (const auto &points : batch) { + stream.push_streamline(points); + } + } + + void finalize() { + stream.finalize(out_path, ZIP_CM_STORE); + } + + private: + trx::TrxStream stream; + std::string out_path; + }; + +This kernel can be used as the final stage of a producer pipeline. The key rule +is: **only the writer thread touches ``TrxStream``**, while worker threads only +generate streamlines. + Optional NIfTI header support ----------------------------- diff --git a/include/trx/trx.h b/include/trx/trx.h index 03ab146..34c3bde 100644 --- a/include/trx/trx.h +++ b/include/trx/trx.h @@ -9,18 +9,24 @@ #include #include #include +#include +#include #include #include #include +#include #include #include #include #include #include #include +#include #include +#include #include #include +#include #include #include @@ -33,6 +39,15 @@ namespace fs = std::filesystem; using json = json11::Json; namespace trx { +enum class TrxSaveMode { Auto, Archive, Directory }; + +struct TrxSaveOptions { + zip_uint32_t compression_standard = ZIP_CM_STORE; + TrxSaveMode mode = TrxSaveMode::Auto; + size_t memory_limit_bytes = 0; // Reserved for future save-path tuning. 
+ bool overwrite_existing = true; +}; + inline json::object _json_object(const json &value) { if (value.is_object()) { return value.object_items(); @@ -110,55 +125,68 @@ inline zip_t *open_zip_for_read(const std::string &path, int &errorp) { } template struct DTypeName { - static constexpr std::string_view value() { return "float16"; } + static constexpr bool supported = false; + static constexpr std::string_view value() { return ""; } }; template <> struct DTypeName { + static constexpr bool supported = true; static constexpr std::string_view value() { return "float16"; } }; template <> struct DTypeName { + static constexpr bool supported = true; static constexpr std::string_view value() { return "float32"; } }; template <> struct DTypeName { + static constexpr bool supported = true; static constexpr std::string_view value() { return "float64"; } }; template <> struct DTypeName { + static constexpr bool supported = true; static constexpr std::string_view value() { return "int8"; } }; template <> struct DTypeName { + static constexpr bool supported = true; static constexpr std::string_view value() { return "int16"; } }; template <> struct DTypeName { + static constexpr bool supported = true; static constexpr std::string_view value() { return "int32"; } }; template <> struct DTypeName { + static constexpr bool supported = true; static constexpr std::string_view value() { return "int64"; } }; template <> struct DTypeName { + static constexpr bool supported = true; static constexpr std::string_view value() { return "uint8"; } }; template <> struct DTypeName { + static constexpr bool supported = true; static constexpr std::string_view value() { return "uint16"; } }; template <> struct DTypeName { + static constexpr bool supported = true; static constexpr std::string_view value() { return "uint32"; } }; template <> struct DTypeName { + static constexpr bool supported = true; static constexpr std::string_view value() { return "uint64"; } }; template inline std::string dtype_from_scalar() { - typedef typename std::remove_cv::type>::type CleanT; + using CleanT = std::remove_cv_t>; + static_assert(DTypeName::supported, "Unsupported dtype for TRX scalar."); return std::string(DTypeName::value()); } @@ -236,6 +264,7 @@ template class TrxFile { std::string root = ""); template friend class TrxReader; + template friend std::unique_ptr> load(const std::string &path); /** * @brief Create a deepcopy of the TrxFile @@ -260,6 +289,14 @@ template class TrxFile { * @param compression_standard The compression standard to use, as defined by libzip (default: no compression) */ void save(const std::string &filename, zip_uint32_t compression_standard = ZIP_CM_STORE); + void save(const std::string &filename, const TrxSaveOptions &options); + /** + * @brief Normalize in-memory arrays for deterministic save semantics. + * + * This trims trailing preallocated rows when detected, rewrites lengths from + * offsets, and synchronizes header counts with the actual payload. + */ + void normalize_for_save(); void add_dps_from_text(const std::string &name, const std::string &dtype, const std::string &path); template @@ -332,6 +369,8 @@ template class TrxFile { * Each entry is {min_x, min_y, min_z, max_x, max_y, max_z} in TRX coordinates. */ std::vector> build_streamline_aabbs() const; + const std::vector> &get_or_build_streamline_aabbs() const; + void invalidate_aabb_cache() const; /** * @brief Extract a subset of streamlines intersecting an axis-aligned box. 
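+
+// NOTE (illustrative sketch, not part of the header): the AABB cache above can
+// drive slab queries like the ones in the benchmark suite. `file` is a
+// hypothetical loaded TrxFile; only get_or_build_streamline_aabbs() and the
+// {min_x, min_y, min_z, max_x, max_y, max_z} entry layout come from this header.
+//
+//   const auto &aabbs = file.get_or_build_streamline_aabbs();
+//   const float z_lo = 10.0f, z_hi = 15.0f;  // a 5 mm axial slab
+//   std::vector<std::uint32_t> hits;
+//   for (std::size_t i = 0; i < aabbs.size(); ++i) {
+//     if (aabbs[i][2] <= z_hi && aabbs[i][5] >= z_lo) {  // z ranges overlap
+//       hits.push_back(static_cast<std::uint32_t>(i));
+//     }
+//   }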
@@ -363,6 +402,11 @@ template class TrxFile { subset_streamlines(const std::vector &streamline_ids, bool build_cache_for_result = false) const; + const MMappedMatrix
<DT> *get_dps(const std::string &name) const;
+  const ArraySequence<DT>
*get_dpv(const std::string &name) const; + std::vector> get_streamline(size_t streamline_index) const; + template void for_each_streamline(Fn &&fn) const; + /** * @brief Add a data-per-group (DPG) field from a flat vector. * @@ -412,7 +456,6 @@ template class TrxFile { void remove_dpg_group(const std::string &group); private: - void invalidate_aabb_cache() const; mutable std::vector> aabb_cache_; /** * @brief Load a TrxFile from a zip archive. @@ -550,6 +593,8 @@ struct TypedArray { } }; +enum class TrxScalarType; + class AnyTrxFile { public: AnyTrxFile() = default; @@ -575,6 +620,19 @@ class AnyTrxFile { size_t num_streamlines() const; void close(); void save(const std::string &filename, zip_uint32_t compression_standard = ZIP_CM_STORE); + void save(const std::string &filename, const TrxSaveOptions &options); + + const TypedArray *get_dps(const std::string &name) const; + const TypedArray *get_dpv(const std::string &name) const; + std::vector> get_streamline(size_t streamline_index) const; + + using PositionsChunkCallback = + std::function; + using PositionsChunkMutableCallback = + std::function; + + void for_each_positions_chunk(size_t chunk_bytes, const PositionsChunkCallback &fn) const; + void for_each_positions_chunk_mutable(size_t chunk_bytes, const PositionsChunkMutableCallback &fn); static AnyTrxFile load(const std::string &path); static AnyTrxFile load_from_zip(const std::string &path); @@ -628,6 +686,42 @@ class TrxStream { */ void push_streamline(const std::vector> &points); + /** + * @brief Set max in-memory position buffer size (bytes). + * + * When set to a non-zero value, positions are buffered in memory and flushed + * to the temp file once the buffer reaches this size. Useful for reducing + * small I/O writes on slow disks. + */ + void set_positions_buffer_max_bytes(std::size_t max_bytes); + + enum class MetadataMode { InMemory, OnDisk }; + + /** + * @brief Control how DPS/DPV/groups are stored during streaming. + * + * InMemory keeps metadata in RAM until finalize (default). + * OnDisk writes metadata to temp files and copies them at finalize. + */ + void set_metadata_mode(MetadataMode mode); + + /** + * @brief Set max in-memory buffer size for metadata writes (bytes). + * + * Applies when MetadataMode::OnDisk. Larger buffers reduce write calls. + */ + void set_metadata_buffer_max_bytes(std::size_t max_bytes); + + /** + * @brief Set the VOXEL_TO_RASMM affine matrix in the header. + */ + void set_voxel_to_rasmm(const Eigen::Matrix4f &affine); + + /** + * @brief Set DIMENSIONS in the header. + */ + void set_dimensions(const std::array &dims); + /** * @brief Add per-streamline values (DPS) from an in-memory vector. */ @@ -647,6 +741,58 @@ class TrxStream { * @brief Finalize and write a TRX file. */ template void finalize(const std::string &filename, zip_uint32_t compression_standard = ZIP_CM_STORE); + void finalize(const std::string &filename, + TrxScalarType output_dtype, + zip_uint32_t compression_standard = ZIP_CM_STORE); + void finalize(const std::string &filename, const TrxSaveOptions &options); + + /** + * @brief Finalize and write a TRX directory (no zip). + * + * This method removes any existing directory at the output path before + * writing. Use this for single-process writes or when you control the + * entire output location lifecycle. + * + * @param directory Path where the uncompressed TRX directory will be created. + * + * @throws std::runtime_error if already finalized or if I/O fails. + * + * @see finalize_directory_persistent for multiprocess-safe variant. 
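+   *
+   * A minimal usage sketch (hypothetical output path):
+   * @code
+   * trx::TrxStream stream("float16");
+   * // ... push streamlines ...
+   * stream.finalize_directory("subject01_tracks_dir");
+   * @endcode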
+ */ + void finalize_directory(const std::string &directory); + + /** + * @brief Finalize and write a TRX directory without removing existing files. + * + * This variant is designed for multiprocess workflows where the output + * directory is pre-created by a parent process. Unlike finalize_directory(), + * this method does NOT remove the output directory if it exists, making it + * safe for coordinated parallel writes where multiple processes may check + * for the directory's existence. + * + * @param directory Path where the uncompressed TRX directory will be created. + * If the directory exists, its contents will be overwritten + * but the directory itself will not be removed and recreated. + * + * @throws std::runtime_error if already finalized or if I/O fails. + * + * @note Typical usage pattern: + * @code + * // Parent process creates shard directories + * fs::create_directories("shards/shard_0"); + * + * // Child process writes without removing directory + * trx::TrxStream stream("float16"); + * // ... push streamlines ... + * stream.finalize_directory_persistent("shards/shard_0"); + * std::ofstream("shards/shard_0/SHARD_OK") << "ok\n"; + * + * // Parent waits for SHARD_OK before reading results + * @endcode + * + * @see finalize_directory for single-process variant that ensures clean slate. + */ + void finalize_directory_persistent(const std::string &directory); size_t num_streamlines() const { return lengths_.size(); } size_t num_vertices() const { return total_vertices_; } @@ -659,13 +805,24 @@ class TrxStream { std::vector values; }; + struct MetadataFile { + std::string relative_path; + std::string absolute_path; + }; + void ensure_positions_stream(); + void flush_positions_buffer(); void cleanup_tmp(); + void ensure_metadata_dir(const std::string &subdir); + void finalize_directory_impl(const std::string &directory, bool remove_existing); std::string positions_dtype_; std::string tmp_dir_; std::string positions_path_; std::ofstream positions_out_; + std::vector positions_buffer_float_; + std::vector positions_buffer_half_; + std::size_t positions_buffer_max_entries_ = 0; std::vector lengths_; size_t total_vertices_ = 0; bool finalized_ = false; @@ -673,6 +830,9 @@ class TrxStream { std::map> groups_; std::map dps_; std::map dpv_; + MetadataMode metadata_mode_ = MetadataMode::InMemory; + std::vector metadata_files_; + std::size_t metadata_buffer_max_bytes_ = 8 * 1024 * 1024; }; /** @@ -740,6 +900,38 @@ inline std::string scalar_type_name(TrxScalarType dtype) { } } +struct PositionsOutputInfo { + std::string directory; + std::string positions_path; + std::string dtype; + size_t points = 0; +}; + +struct PrepareOutputOptions { + bool overwrite_existing = true; +}; + +/** + * @brief Prepare an output directory with copied metadata and offsets. + * + * Creates a new TRX directory (no zip) that contains header, offsets, and + * metadata (groups, dps, dpv, dpg), and returns where the positions file + * should be written. + */ +PositionsOutputInfo prepare_positions_output(const AnyTrxFile &input, + const std::string &output_directory, + const PrepareOutputOptions &options = {}); + +struct MergeTrxShardsOptions { + std::vector shard_directories; + std::string output_path; + zip_uint32_t compression_standard = ZIP_CM_STORE; + bool output_directory = false; + bool overwrite_existing = true; +}; + +void merge_trx_shards(const MergeTrxShardsOptions &options); + /** * @brief Detect the positions scalar type for a TRX path. 
* @@ -877,7 +1069,8 @@ void ediff1d(Eigen::Matrix &lengths, void zip_from_folder(zip_t *zf, const std::string &root, const std::string &directory, - zip_uint32_t compression_standard = ZIP_CM_STORE); + zip_uint32_t compression_standard = ZIP_CM_STORE, + const std::unordered_set *skip = nullptr); std::string get_base(const std::string &delimiter, const std::string &str); std::string get_ext(const std::string &str); diff --git a/include/trx/trx.tpp b/include/trx/trx.tpp index f0c8079..fff2f05 100644 --- a/include/trx/trx.tpp +++ b/include/trx/trx.tpp @@ -169,7 +169,7 @@ std::unique_ptr> _initialize_empty_trx(int nb_streamlines, int nb_ve offsets_dtype = dtype_from_scalar(); lengths_dtype = dtype_from_scalar(); } else { - positions_dtype = dtype_from_scalar(); + positions_dtype = dtype_from_scalar
(); offsets_dtype = dtype_from_scalar(); lengths_dtype = dtype_from_scalar(); } @@ -181,8 +181,7 @@ std::unique_ptr> _initialize_empty_trx(int nb_streamlines, int nb_ve trx->streamlines = std::make_unique>(); trx->streamlines->mmap_pos = trx::_create_memmap(positions_filename, shape, "w+", positions_dtype); - // TODO: find a better way to get the dtype than using all these switch cases. Also refactor - // into function as per specifications, positions can only be floats + // TODO: find a better way to get the dtype than using all these switch cases. if (positions_dtype.compare("float16") == 0) { new (&(trx->streamlines->_data)) Map>( reinterpret_cast(trx->streamlines->mmap_pos.data()), std::get<0>(shape), std::get<1>(shape)); @@ -356,9 +355,12 @@ TrxFile
::_create_trx_from_pointer(json header, long long size = std::get<1>(x->second); if (base.compare("positions") == 0 && (folder.compare("") == 0 || folder.compare(".") == 0)) { - if (size != static_cast(trx->header["NB_VERTICES"].int_value()) * 3 || dim != 3) { - - throw std::invalid_argument("Wrong data size/dimensionality"); + const auto nb_vertices = static_cast(trx->header["NB_VERTICES"].int_value()); + const auto expected = nb_vertices * 3; + if (size != expected || dim != 3) { + throw std::invalid_argument("Wrong data size/dimensionality: size=" + std::to_string(size) + + " expected=" + std::to_string(expected) + " dim=" + std::to_string(dim) + + " filename=" + elem_filename); } std::tuple shape = std::make_tuple(static_cast(trx->header["NB_VERTICES"].int_value()), 3); @@ -380,11 +382,12 @@ TrxFile
::_create_trx_from_pointer(json header, } else if (base.compare("offsets") == 0 && (folder.compare("") == 0 || folder.compare(".") == 0)) { - if (size != static_cast(trx->header["NB_STREAMLINES"].int_value()) + 1 || dim != 1) { - throw std::invalid_argument( - "Wrong offsets size/dimensionality: size=" + std::to_string(size) + - " nb_streamlines=" + std::to_string(static_cast(trx->header["NB_STREAMLINES"].int_value())) + - " dim=" + std::to_string(dim) + " filename=" + elem_filename); + const auto nb_streamlines = static_cast(trx->header["NB_STREAMLINES"].int_value()); + const auto expected = nb_streamlines + 1; + if (size != expected || dim != 1) { + throw std::invalid_argument("Wrong offsets size/dimensionality: size=" + std::to_string(size) + + " expected=" + std::to_string(expected) + " dim=" + std::to_string(dim) + + " filename=" + elem_filename); } const int nb_str = static_cast(trx->header["NB_STREAMLINES"].int_value()); @@ -965,9 +968,42 @@ template std::unique_ptr> TrxFile
::load_from_direc std::string header_name = directory + SEPARATOR + "header.json"; // TODO: add check to verify that it's open - std::ifstream header_file(header_name); + std::ifstream header_file; + for (int attempt = 0; attempt < 5; ++attempt) { + header_file.open(header_name); + if (header_file.is_open()) { + break; + } + std::this_thread::sleep_for(std::chrono::milliseconds(50)); + } if (!header_file.is_open()) { - throw std::runtime_error("Failed to open header.json at: " + header_name); + std::error_code ec; + const bool exists = trx::fs::exists(directory, ec); + const int open_err = errno; + std::string detail = "Failed to open header.json at: " + header_name; + detail += " exists=" + std::string(exists ? "true" : "false"); + detail += " errno=" + std::to_string(open_err) + " msg=" + std::string(std::strerror(open_err)); + if (exists) { + std::vector files; + for (const auto &entry : trx::fs::directory_iterator(directory, ec)) { + if (ec) { + break; + } + files.push_back(entry.path().filename().string()); + } + if (!files.empty()) { + std::sort(files.begin(), files.end()); + detail += " files=["; + for (size_t i = 0; i < files.size(); ++i) { + if (i > 0) { + detail += ","; + } + detail += files[i]; + } + detail += "]"; + } + } + throw std::runtime_error(detail); } std::string jstream((std::istreambuf_iterator(header_file)), std::istreambuf_iterator()); header_file.close(); @@ -980,7 +1016,10 @@ template std::unique_ptr> TrxFile
::load_from_direc
   std::map<std::string, std::tuple<long long, long long>> files_pointer_size;
   populate_fps(directory, files_pointer_size);
-  return TrxFile<DT>::_create_trx_from_pointer(header, files_pointer_size, "", directory);
+  auto trx = TrxFile<DT>::_create_trx_from_pointer(header, files_pointer_size, "", directory);
+  trx->_uncompressed_folder_handle = directory;
+  trx->_owns_uncompressed_folder = false;
+  return trx;
+}
+
 template <typename DT> std::unique_ptr<TrxFile<DT>> TrxFile<DT>::load(const std::string &path) {
@@ -995,6 +1034,10 @@ template <typename DT> std::unique_ptr<TrxFile<DT>> TrxFile<DT>::load(const std:
   return TrxFile<DT>::load_from_zip(path);
 }
+template <typename DT> std::unique_ptr<TrxFile<DT>> load(const std::string &path) {
+  return TrxFile<DT>::load(path);
+}
+
 template <typename DT> TrxReader<DT>::TrxReader(const std::string &path) { trx_ = TrxFile<DT>::load(path); }
 template <typename DT> TrxReader<DT>::TrxReader(TrxReader &&other) noexcept : trx_(std::move(other.trx_)) {}
@@ -1028,52 +1071,161 @@ auto with_trx_reader(const std::string &path, Fn &&fn)
 }

 template <typename DT> void TrxFile<DT>::save(const std::string &filename, zip_uint32_t compression_standard) {
+  TrxSaveOptions options;
+  options.compression_standard = compression_standard;
+  save(filename, options);
+}
+
+template <typename DT> void TrxFile<DT>
::normalize_for_save() { + if (!this->streamlines) { + throw std::runtime_error("Cannot normalize TRX without streamline data"); + } + if (this->streamlines->_offsets.size() == 0) { + throw std::runtime_error("Cannot normalize TRX without offsets data"); + } + + const size_t offsets_count = static_cast(this->streamlines->_offsets.size()); + if (offsets_count < 1) { + throw std::runtime_error("Invalid offsets array"); + } + const size_t total_streamlines = offsets_count - 1; + const uint64_t data_rows = static_cast(this->streamlines->_data.rows()); + + size_t used_streamlines = total_streamlines; + for (size_t i = 1; i < offsets_count; ++i) { + const uint64_t prev = static_cast(this->streamlines->_offsets(static_cast(i - 1))); + const uint64_t curr = static_cast(this->streamlines->_offsets(static_cast(i))); + if (curr < prev || curr > data_rows) { + used_streamlines = i - 1; + break; + } + } + + const uint64_t used_vertices = + static_cast(this->streamlines->_offsets(static_cast(used_streamlines))); + if (used_vertices > data_rows) { + throw std::runtime_error("TRX offsets exceed positions row count"); + } + if (used_vertices > static_cast(std::numeric_limits::max()) || + used_streamlines > static_cast(std::numeric_limits::max())) { + throw std::runtime_error("TRX normalize_for_save exceeds supported int range"); + } + + if (used_streamlines < total_streamlines || used_vertices < data_rows) { + this->resize(static_cast(used_streamlines), static_cast(used_vertices)); + } + + const size_t normalized_streamlines = static_cast(this->streamlines->_offsets.size()) - 1; + for (size_t i = 0; i < normalized_streamlines; ++i) { + const uint64_t curr = static_cast(this->streamlines->_offsets(static_cast(i))); + const uint64_t next = static_cast(this->streamlines->_offsets(static_cast(i + 1))); + if (next < curr) { + throw std::runtime_error("TRX offsets must be monotonically increasing"); + } + const uint64_t diff = next - curr; + if (diff > static_cast(std::numeric_limits::max())) { + throw std::runtime_error("TRX streamline length exceeds uint32 range"); + } + this->streamlines->_lengths(static_cast(i)) = static_cast(diff); + } + + const uint64_t sentinel = static_cast( + this->streamlines->_offsets(static_cast(this->streamlines->_offsets.size() - 1))); + this->header = _json_set(this->header, "NB_STREAMLINES", static_cast(normalized_streamlines)); + this->header = _json_set(this->header, "NB_VERTICES", static_cast(sentinel)); +} + +template void TrxFile
::save(const std::string &filename, const TrxSaveOptions &options) {
   std::string ext = get_ext(filename);
-  if (ext.size() > 0 && (ext != "zip" && ext != "trx")) {
-    throw std::invalid_argument("Unsupported extension." + ext);
+  if (ext.size() > 0 && ext != "zip" && ext != "trx") {
+    throw std::invalid_argument("Unsupported extension: " + ext);
   }
-  auto copy_trx = this->deepcopy();
-  copy_trx->resize();
-  if (!copy_trx->streamlines || copy_trx->streamlines->_offsets.size() == 0) {
+  TrxFile<DT>
*save_trx = this; + + if (!save_trx->streamlines || save_trx->streamlines->_offsets.size() == 0) { throw std::runtime_error("Cannot save TRX without offsets data"); } - if (copy_trx->header["NB_STREAMLINES"].is_number()) { - const auto nb_streamlines = static_cast(copy_trx->header["NB_STREAMLINES"].int_value()); - if (copy_trx->streamlines->_offsets.size() != static_cast(nb_streamlines + 1)) { + if (save_trx->header["NB_STREAMLINES"].is_number()) { + const auto nb_streamlines = static_cast(save_trx->header["NB_STREAMLINES"].int_value()); + if (save_trx->streamlines->_offsets.size() != static_cast(nb_streamlines + 1)) { throw std::runtime_error("TRX offsets size does not match NB_STREAMLINES"); } } - if (copy_trx->header["NB_VERTICES"].is_number()) { - const auto nb_vertices = static_cast(copy_trx->header["NB_VERTICES"].int_value()); + if (save_trx->header["NB_VERTICES"].is_number()) { + const auto nb_vertices = static_cast(save_trx->header["NB_VERTICES"].int_value()); const auto last = - static_cast(copy_trx->streamlines->_offsets(copy_trx->streamlines->_offsets.size() - 1)); + static_cast(save_trx->streamlines->_offsets(save_trx->streamlines->_offsets.size() - 1)); if (last != nb_vertices) { throw std::runtime_error("TRX offsets sentinel does not match NB_VERTICES"); } } - for (Eigen::Index i = 1; i < copy_trx->streamlines->_offsets.size(); ++i) { - if (copy_trx->streamlines->_offsets(i) < copy_trx->streamlines->_offsets(i - 1)) { + for (Eigen::Index i = 1; i < save_trx->streamlines->_offsets.size(); ++i) { + if (save_trx->streamlines->_offsets(i) < save_trx->streamlines->_offsets(i - 1)) { throw std::runtime_error("TRX offsets must be monotonically increasing"); } } - if (copy_trx->streamlines->_data.size() > 0) { + if (save_trx->streamlines->_data.size() > 0) { const auto last = - static_cast(copy_trx->streamlines->_offsets(copy_trx->streamlines->_offsets.size() - 1)); - if (last != static_cast(copy_trx->streamlines->_data.rows())) { + static_cast(save_trx->streamlines->_offsets(save_trx->streamlines->_offsets.size() - 1)); + if (last != static_cast(save_trx->streamlines->_data.rows())) { throw std::runtime_error("TRX positions row count does not match offsets sentinel"); } } - std::string tmp_dir_name = copy_trx->_uncompressed_folder_handle; + std::string tmp_dir_name = save_trx->_uncompressed_folder_handle; + + if (!tmp_dir_name.empty()) { + const std::string header_path = tmp_dir_name + SEPARATOR + "header.json"; + std::ofstream out_json(header_path, std::ios::out | std::ios::trunc); + if (!out_json.is_open()) { + throw std::runtime_error("Failed to write header.json to: " + header_path); + } + out_json << save_trx->header.dump() << std::endl; + out_json.close(); + } + + const bool write_archive = options.mode == TrxSaveMode::Archive || + (options.mode == TrxSaveMode::Auto && ext.size() > 0 && (ext == "zip" || ext == "trx")); + if (write_archive) { + auto sync_unmap_seq = [&](auto &seq) { + if (!seq) { + return; + } + std::error_code ec; + seq->mmap_pos.sync(ec); + seq->mmap_off.sync(ec); + }; + auto sync_unmap_mat = [&](auto &mat) { + if (!mat) { + return; + } + std::error_code ec; + mat->mmap.sync(ec); + }; + + sync_unmap_seq(save_trx->streamlines); + for (auto &kv : save_trx->groups) { + sync_unmap_mat(kv.second); + } + for (auto &kv : save_trx->data_per_streamline) { + sync_unmap_mat(kv.second); + } + for (auto &kv : save_trx->data_per_vertex) { + sync_unmap_seq(kv.second); + } + for (auto &group_kv : save_trx->data_per_group) { + for (auto &kv : group_kv.second) { + 
sync_unmap_mat(kv.second); + } + } - if (ext.size() > 0 && (ext == "zip" || ext == "trx")) { int errorp; zip_t *zf; if ((zf = zip_open(filename.c_str(), ZIP_CREATE + ZIP_TRUNCATE, &errorp)) == nullptr) { throw std::runtime_error("Could not open archive " + filename + ": " + strerror(errorp)); } else { - zip_from_folder(zf, tmp_dir_name, tmp_dir_name, compression_standard); + zip_from_folder(zf, tmp_dir_name, tmp_dir_name, options.compression_standard, nullptr); if (zip_close(zf) != 0) { throw std::runtime_error("Unable to close archive " + filename + ": " + zip_strerror(zf)); } @@ -1084,6 +1236,9 @@ template void TrxFile
::save(const std::string &filename, zip_u throw std::runtime_error("Temporary TRX directory does not exist: " + tmp_dir_name); } if (trx::fs::exists(filename, ec) && trx::fs::is_directory(filename, ec)) { + if (!options.overwrite_existing) { + throw std::runtime_error("Output directory already exists: " + filename); + } if (rm_dir(filename) != 0) { throw std::runtime_error("Could not remove existing directory " + filename); } @@ -1105,7 +1260,6 @@ template void TrxFile
::save(const std::string &filename, zip_u if (!trx::fs::exists(header_path)) { throw std::runtime_error("Missing header.json in output directory: " + header_path.string()); } - copy_trx->close(); } } @@ -1388,8 +1542,8 @@ inline TrxStream::TrxStream(std::string positions_dtype) : positions_dtype_(std: std::transform(positions_dtype_.begin(), positions_dtype_.end(), positions_dtype_.begin(), [](unsigned char c) { return static_cast(std::tolower(c)); }); - if (positions_dtype_ != "float32") { - throw std::invalid_argument("TrxStream only supports float32 positions for now"); + if (positions_dtype_ != "float32" && positions_dtype_ != "float16") { + throw std::invalid_argument("TrxStream only supports float16/float32 positions for now"); } tmp_dir_ = make_temp_dir("trx_proto"); positions_path_ = tmp_dir_ + SEPARATOR + "positions.tmp"; @@ -1398,6 +1552,20 @@ inline TrxStream::TrxStream(std::string positions_dtype) : positions_dtype_(std: inline TrxStream::~TrxStream() { cleanup_tmp(); } +inline void TrxStream::set_metadata_mode(MetadataMode mode) { + if (finalized_) { + throw std::runtime_error("Cannot adjust metadata mode after finalize"); + } + metadata_mode_ = mode; +} + +inline void TrxStream::set_metadata_buffer_max_bytes(std::size_t max_bytes) { + if (finalized_) { + throw std::runtime_error("Cannot adjust metadata buffer after finalize"); + } + metadata_buffer_max_bytes_ = max_bytes; +} + inline void TrxStream::ensure_positions_stream() { if (!positions_out_.is_open()) { positions_out_.open(positions_path_, std::ios::binary | std::ios::out | std::ios::trunc); @@ -1407,7 +1575,50 @@ inline void TrxStream::ensure_positions_stream() { } } +inline void TrxStream::ensure_metadata_dir(const std::string &subdir) { + if (tmp_dir_.empty()) { + throw std::runtime_error("TrxStream temp directory not initialized"); + } + const std::string dir = tmp_dir_ + SEPARATOR + subdir + SEPARATOR; + std::error_code ec; + trx::fs::create_directories(dir, ec); + if (ec) { + throw std::runtime_error("Could not create directory " + dir); + } +} + +inline void TrxStream::flush_positions_buffer() { + if (positions_dtype_ == "float16") { + if (positions_buffer_half_.empty()) { + return; + } + ensure_positions_stream(); + const size_t byte_count = positions_buffer_half_.size() * sizeof(half); + positions_out_.write(reinterpret_cast(positions_buffer_half_.data()), + static_cast(byte_count)); + if (!positions_out_) { + throw std::runtime_error("Failed to write TrxStream positions buffer"); + } + positions_buffer_half_.clear(); + return; + } + + if (positions_buffer_float_.empty()) { + return; + } + ensure_positions_stream(); + const size_t byte_count = positions_buffer_float_.size() * sizeof(float); + positions_out_.write(reinterpret_cast(positions_buffer_float_.data()), + static_cast(byte_count)); + if (!positions_out_) { + throw std::runtime_error("Failed to write TrxStream positions buffer"); + } + positions_buffer_float_.clear(); +} + inline void TrxStream::cleanup_tmp() { + positions_buffer_float_.clear(); + positions_buffer_half_.clear(); if (positions_out_.is_open()) { positions_out_.close(); } @@ -1425,11 +1636,42 @@ inline void TrxStream::push_streamline(const float *xyz, size_t point_count) { lengths_.push_back(0); return; } - ensure_positions_stream(); - const size_t byte_count = point_count * 3 * sizeof(float); - positions_out_.write(reinterpret_cast(xyz), static_cast(byte_count)); - if (!positions_out_) { - throw std::runtime_error("Failed to write TrxStream positions"); + if (positions_buffer_max_entries_ 
== 0) { + ensure_positions_stream(); + if (positions_dtype_ == "float16") { + std::vector tmp; + tmp.reserve(point_count * 3); + for (size_t i = 0; i < point_count * 3; ++i) { + tmp.push_back(static_cast(xyz[i])); + } + const size_t byte_count = tmp.size() * sizeof(half); + positions_out_.write(reinterpret_cast(tmp.data()), static_cast(byte_count)); + if (!positions_out_) { + throw std::runtime_error("Failed to write TrxStream positions"); + } + } else { + const size_t byte_count = point_count * 3 * sizeof(float); + positions_out_.write(reinterpret_cast(xyz), static_cast(byte_count)); + if (!positions_out_) { + throw std::runtime_error("Failed to write TrxStream positions"); + } + } + } else { + const size_t floats_count = point_count * 3; + if (positions_dtype_ == "float16") { + positions_buffer_half_.reserve(positions_buffer_half_.size() + floats_count); + for (size_t i = 0; i < floats_count; ++i) { + positions_buffer_half_.push_back(static_cast(xyz[i])); + } + if (positions_buffer_half_.size() >= positions_buffer_max_entries_) { + flush_positions_buffer(); + } + } else { + positions_buffer_float_.insert(positions_buffer_float_.end(), xyz, xyz + floats_count); + if (positions_buffer_float_.size() >= positions_buffer_max_entries_) { + flush_positions_buffer(); + } + } } total_vertices_ += point_count; lengths_.push_back(static_cast(point_count)); @@ -1443,7 +1685,32 @@ inline void TrxStream::push_streamline(const std::vector &xyz_flat) { } inline void TrxStream::push_streamline(const std::vector> &points) { - push_streamline(reinterpret_cast(points.data()), points.size()); + if (points.empty()) { + push_streamline(static_cast(nullptr), 0); + return; + } + std::vector xyz_flat; + xyz_flat.reserve(points.size() * 3); + for (const auto &point : points) { + xyz_flat.push_back(point[0]); + xyz_flat.push_back(point[1]); + xyz_flat.push_back(point[2]); + } + push_streamline(xyz_flat); +} + +inline void TrxStream::set_voxel_to_rasmm(const Eigen::Matrix4f &affine) { + std::vector> matrix(4, std::vector(4, 0.0f)); + for (int i = 0; i < 4; ++i) { + for (int j = 0; j < 4; ++j) { + matrix[static_cast(i)][static_cast(j)] = affine(i, j); + } + } + header = _json_set(header, "VOXEL_TO_RASMM", matrix); +} + +inline void TrxStream::set_dimensions(const std::array &dims) { + header = _json_set(header, "DIMENSIONS", std::vector{dims[0], dims[1], dims[2]}); } template @@ -1462,13 +1729,67 @@ TrxStream::push_dps_from_vector(const std::string &name, const std::string &dtyp if (dtype_norm != "float16" && dtype_norm != "float32" && dtype_norm != "float64") { throw std::invalid_argument("Unsupported DPS dtype: " + dtype); } - FieldValues field; - field.dtype = dtype_norm; - field.values.reserve(values.size()); - for (const auto &v : values) { - field.values.push_back(static_cast(v)); + if (metadata_mode_ == MetadataMode::OnDisk) { + ensure_metadata_dir("dps"); + const std::string filename = tmp_dir_ + SEPARATOR + "dps" + SEPARATOR + name + "." 
+ dtype_norm; + std::ofstream out(filename, std::ios::binary | std::ios::out | std::ios::trunc); + if (!out.is_open()) { + throw std::runtime_error("Failed to open DPS file: " + filename); + } + if (dtype_norm == "float16") { + const size_t chunk_elems = std::max(1, metadata_buffer_max_bytes_ / sizeof(half)); + std::vector tmp; + tmp.reserve(chunk_elems); + size_t offset = 0; + while (offset < values.size()) { + const size_t count = std::min(chunk_elems, values.size() - offset); + tmp.clear(); + for (size_t i = 0; i < count; ++i) { + tmp.push_back(static_cast(values[offset + i])); + } + out.write(reinterpret_cast(tmp.data()), static_cast(count * sizeof(half))); + offset += count; + } + } else if (dtype_norm == "float32") { + const size_t chunk_elems = std::max(1, metadata_buffer_max_bytes_ / sizeof(float)); + std::vector tmp; + tmp.reserve(chunk_elems); + size_t offset = 0; + while (offset < values.size()) { + const size_t count = std::min(chunk_elems, values.size() - offset); + tmp.clear(); + for (size_t i = 0; i < count; ++i) { + tmp.push_back(static_cast(values[offset + i])); + } + out.write(reinterpret_cast(tmp.data()), static_cast(count * sizeof(float))); + offset += count; + } + } else { + const size_t chunk_elems = std::max(1, metadata_buffer_max_bytes_ / sizeof(double)); + std::vector tmp; + tmp.reserve(chunk_elems); + size_t offset = 0; + while (offset < values.size()) { + const size_t count = std::min(chunk_elems, values.size() - offset); + tmp.clear(); + for (size_t i = 0; i < count; ++i) { + tmp.push_back(static_cast(values[offset + i])); + } + out.write(reinterpret_cast(tmp.data()), static_cast(count * sizeof(double))); + offset += count; + } + } + out.close(); + metadata_files_.push_back({std::string("dps") + SEPARATOR + name + "." + dtype_norm, filename}); + } else { + FieldValues field; + field.dtype = dtype_norm; + field.values.reserve(values.size()); + for (const auto &v : values) { + field.values.push_back(static_cast(v)); + } + dps_[name] = std::move(field); } - dps_[name] = std::move(field); } template @@ -1487,20 +1808,113 @@ TrxStream::push_dpv_from_vector(const std::string &name, const std::string &dtyp if (dtype_norm != "float16" && dtype_norm != "float32" && dtype_norm != "float64") { throw std::invalid_argument("Unsupported DPV dtype: " + dtype); } - FieldValues field; - field.dtype = dtype_norm; - field.values.reserve(values.size()); - for (const auto &v : values) { - field.values.push_back(static_cast(v)); + if (metadata_mode_ == MetadataMode::OnDisk) { + ensure_metadata_dir("dpv"); + const std::string filename = tmp_dir_ + SEPARATOR + "dpv" + SEPARATOR + name + "." 
+ dtype_norm; + std::ofstream out(filename, std::ios::binary | std::ios::out | std::ios::trunc); + if (!out.is_open()) { + throw std::runtime_error("Failed to open DPV file: " + filename); + } + if (dtype_norm == "float16") { + const size_t chunk_elems = std::max(1, metadata_buffer_max_bytes_ / sizeof(half)); + std::vector tmp; + tmp.reserve(chunk_elems); + size_t offset = 0; + while (offset < values.size()) { + const size_t count = std::min(chunk_elems, values.size() - offset); + tmp.clear(); + for (size_t i = 0; i < count; ++i) { + tmp.push_back(static_cast(values[offset + i])); + } + out.write(reinterpret_cast(tmp.data()), static_cast(count * sizeof(half))); + offset += count; + } + } else if (dtype_norm == "float32") { + const size_t chunk_elems = std::max(1, metadata_buffer_max_bytes_ / sizeof(float)); + std::vector tmp; + tmp.reserve(chunk_elems); + size_t offset = 0; + while (offset < values.size()) { + const size_t count = std::min(chunk_elems, values.size() - offset); + tmp.clear(); + for (size_t i = 0; i < count; ++i) { + tmp.push_back(static_cast(values[offset + i])); + } + out.write(reinterpret_cast(tmp.data()), static_cast(count * sizeof(float))); + offset += count; + } + } else { + const size_t chunk_elems = std::max(1, metadata_buffer_max_bytes_ / sizeof(double)); + std::vector tmp; + tmp.reserve(chunk_elems); + size_t offset = 0; + while (offset < values.size()) { + const size_t count = std::min(chunk_elems, values.size() - offset); + tmp.clear(); + for (size_t i = 0; i < count; ++i) { + tmp.push_back(static_cast(values[offset + i])); + } + out.write(reinterpret_cast(tmp.data()), static_cast(count * sizeof(double))); + offset += count; + } + } + out.close(); + metadata_files_.push_back({std::string("dpv") + SEPARATOR + name + "." + dtype_norm, filename}); + } else { + FieldValues field; + field.dtype = dtype_norm; + field.values.reserve(values.size()); + for (const auto &v : values) { + field.values.push_back(static_cast(v)); + } + dpv_[name] = std::move(field); + } +} + +inline void TrxStream::set_positions_buffer_max_bytes(std::size_t max_bytes) { + if (finalized_) { + throw std::runtime_error("Cannot adjust buffer after finalize"); + } + if (max_bytes == 0) { + positions_buffer_max_entries_ = 0; + positions_buffer_float_.clear(); + positions_buffer_half_.clear(); + return; + } + const std::size_t element_size = positions_dtype_ == "float16" ? 
sizeof(half) : sizeof(float); + const std::size_t entries = max_bytes / element_size; + const std::size_t aligned = (entries / 3) * 3; + positions_buffer_max_entries_ = aligned; + if (positions_buffer_max_entries_ == 0) { + positions_buffer_float_.clear(); + positions_buffer_half_.clear(); } - dpv_[name] = std::move(field); } inline void TrxStream::push_group_from_indices(const std::string &name, const std::vector &indices) { if (name.empty()) { throw std::invalid_argument("Group name cannot be empty"); } - groups_[name] = indices; + if (metadata_mode_ == MetadataMode::OnDisk) { + ensure_metadata_dir("groups"); + const std::string filename = tmp_dir_ + SEPARATOR + "groups" + SEPARATOR + name + ".uint32"; + std::ofstream out(filename, std::ios::binary | std::ios::out | std::ios::trunc); + if (!out.is_open()) { + throw std::runtime_error("Failed to open group file: " + filename); + } + const size_t chunk_elems = std::max(1, metadata_buffer_max_bytes_ / sizeof(uint32_t)); + size_t offset = 0; + while (offset < indices.size()) { + const size_t count = std::min(chunk_elems, indices.size() - offset); + out.write(reinterpret_cast(indices.data() + offset), + static_cast(count * sizeof(uint32_t))); + offset += count; + } + out.close(); + metadata_files_.push_back({std::string("groups") + SEPARATOR + name + ".uint32", filename}); + } else { + groups_[name] = indices; + } } template void TrxStream::finalize(const std::string &filename, zip_uint32_t compression_standard) { @@ -1509,6 +1923,7 @@ template void TrxStream::finalize(const std::string &filename, zip } finalized_ = true; + flush_positions_buffer(); if (positions_out_.is_open()) { positions_out_.flush(); positions_out_.close(); @@ -1539,14 +1954,25 @@ template void TrxStream::finalize(const std::string &filename, zip throw std::runtime_error("Failed to open TrxStream temp positions file for read: " + positions_path_); } for (size_t i = 0; i < nb_vertices; ++i) { - float xyz[3]; - in.read(reinterpret_cast(xyz), sizeof(xyz)); - if (!in) { - throw std::runtime_error("Failed to read TrxStream positions"); + if (positions_dtype_ == "float16") { + half xyz[3]; + in.read(reinterpret_cast(xyz), sizeof(xyz)); + if (!in) { + throw std::runtime_error("Failed to read TrxStream positions"); + } + positions(static_cast(i), 0) = static_cast
 
 template <typename DT> void TrxStream::finalize(const std::string &filename, zip_uint32_t compression_standard) {
@@ -1509,6 +1923,7 @@ template <typename DT> void TrxStream::finalize(const std::string &filename, zip
   }
 
   finalized_ = true;
+  flush_positions_buffer();
   if (positions_out_.is_open()) {
     positions_out_.flush();
     positions_out_.close();
@@ -1539,14 +1954,25 @@ template <typename DT> void TrxStream::finalize(const std::string &filename, zip
     throw std::runtime_error("Failed to open TrxStream temp positions file for read: " + positions_path_);
   }
   for (size_t i = 0; i < nb_vertices; ++i) {
-    float xyz[3];
-    in.read(reinterpret_cast<char *>(xyz), sizeof(xyz));
-    if (!in) {
-      throw std::runtime_error("Failed to read TrxStream positions");
+    if (positions_dtype_ == "float16") {
+      half xyz[3];
+      in.read(reinterpret_cast<char *>(xyz), sizeof(xyz));
+      if (!in) {
+        throw std::runtime_error("Failed to read TrxStream positions");
+      }
+      positions(static_cast<Eigen::Index>(i), 0) = static_cast<DT>(xyz[0]);
+      positions(static_cast<Eigen::Index>(i), 1) = static_cast<DT>(xyz[1]);
+      positions(static_cast<Eigen::Index>(i), 2) = static_cast<DT>(xyz[2]);
+    } else {
+      float xyz[3];
+      in.read(reinterpret_cast<char *>(xyz), sizeof(xyz));
+      if (!in) {
+        throw std::runtime_error("Failed to read TrxStream positions");
+      }
+      positions(static_cast<Eigen::Index>(i), 0) = static_cast<DT>(xyz[0]);
+      positions(static_cast<Eigen::Index>(i), 1) = static_cast<DT>(xyz[1]);
+      positions(static_cast<Eigen::Index>(i), 2) = static_cast<DT>(xyz[2]);
+    }
-    }
-    positions(static_cast<Eigen::Index>(i), 0) = static_cast<DT>(xyz[0]);
-    positions(static_cast<Eigen::Index>(i), 1) = static_cast<DT>(xyz[1]);
-    positions(static_cast<Eigen::Index>(i), 2) = static_cast<DT>
(xyz[2]); } for (const auto &kv : dps_) { @@ -1559,12 +1985,244 @@ template void TrxStream::finalize(const std::string &filename, zip trx.add_group_from_indices(kv.first, kv.second); } + if (metadata_mode_ == MetadataMode::OnDisk) { + for (const auto &meta : metadata_files_) { + const std::string dest = trx._uncompressed_folder_handle + SEPARATOR + meta.relative_path; + const trx::fs::path dest_path(dest); + if (dest_path.has_parent_path()) { + std::error_code parent_ec; + trx::fs::create_directories(dest_path.parent_path(), parent_ec); + } + std::error_code copy_ec; + trx::fs::copy_file(meta.absolute_path, dest, trx::fs::copy_options::overwrite_existing, copy_ec); + if (copy_ec) { + throw std::runtime_error("Failed to copy metadata file: " + meta.absolute_path + " -> " + dest); + } + } + } + trx.save(filename, compression_standard); trx.close(); cleanup_tmp(); } +inline void TrxStream::finalize(const std::string &filename, + TrxScalarType output_dtype, + zip_uint32_t compression_standard) { + switch (output_dtype) { + case TrxScalarType::Float16: + finalize(filename, compression_standard); + break; + case TrxScalarType::Float64: + finalize(filename, compression_standard); + break; + case TrxScalarType::Float32: + default: + finalize(filename, compression_standard); + break; + } +} + +inline void TrxStream::finalize(const std::string &filename, const TrxSaveOptions &options) { + if (options.mode == TrxSaveMode::Directory) { + if (finalized_) { + throw std::runtime_error("TrxStream already finalized"); + } + if (options.overwrite_existing) { + finalize_directory(filename); + } else { + finalize_directory_persistent(filename); + } + return; + } + + TrxScalarType out_type = TrxScalarType::Float32; + if (positions_dtype_ == "float16") { + out_type = TrxScalarType::Float16; + } else if (positions_dtype_ == "float64") { + out_type = TrxScalarType::Float64; + } + finalize(filename, out_type, options.compression_standard); +} + +inline void TrxStream::finalize_directory_impl(const std::string &directory, bool remove_existing) { + if (finalized_) { + throw std::runtime_error("TrxStream already finalized"); + } + finalized_ = true; + + flush_positions_buffer(); + if (positions_out_.is_open()) { + positions_out_.flush(); + positions_out_.close(); + } + + const size_t nb_streamlines = lengths_.size(); + const size_t nb_vertices = total_vertices_; + + std::error_code ec; + if (remove_existing && trx::fs::exists(directory, ec)) { + trx::fs::remove_all(directory, ec); + ec.clear(); + } + + // Create directory if it doesn't exist + if (!trx::fs::exists(directory, ec)) { + trx::fs::create_directories(directory, ec); + if (ec) { + throw std::runtime_error("Failed to create output directory: " + directory); + } + } + ec.clear(); + + json header_out = header; + header_out = _json_set(header_out, "NB_VERTICES", static_cast(nb_vertices)); + header_out = _json_set(header_out, "NB_STREAMLINES", static_cast(nb_streamlines)); + const std::string header_path = directory + SEPARATOR + "header.json"; + std::ofstream out_header(header_path, std::ios::out | std::ios::trunc); + if (!out_header.is_open()) { + throw std::runtime_error("Failed to write header.json to: " + header_path); + } + out_header << header_out.dump() << std::endl; + out_header.close(); + + const std::string positions_name = "positions.3." 
+ positions_dtype_; + const std::string positions_dst = directory + SEPARATOR + positions_name; + trx::fs::rename(positions_path_, positions_dst, ec); + if (ec) { + ec.clear(); + trx::fs::copy_file(positions_path_, positions_dst, trx::fs::copy_options::overwrite_existing, ec); + if (ec) { + throw std::runtime_error("Failed to copy positions file to: " + positions_dst); + } + } + + const std::string offsets_dst = directory + SEPARATOR + "offsets.uint64"; + std::ofstream offsets_out(offsets_dst, std::ios::binary | std::ios::out | std::ios::trunc); + if (!offsets_out.is_open()) { + throw std::runtime_error("Failed to open offsets file for write: " + offsets_dst); + } + uint64_t offset = 0; + offsets_out.write(reinterpret_cast(&offset), sizeof(offset)); + for (const auto length : lengths_) { + offset += static_cast(length); + offsets_out.write(reinterpret_cast(&offset), sizeof(offset)); + } + offsets_out.flush(); + offsets_out.close(); + + auto write_field_values = [&](const std::string &path, const FieldValues &values) { + std::ofstream out(path, std::ios::binary | std::ios::out | std::ios::trunc); + if (!out.is_open()) { + throw std::runtime_error("Failed to open metadata file: " + path); + } + const size_t count = values.values.size(); + if (values.dtype == "float16") { + const size_t chunk = std::max(1, metadata_buffer_max_bytes_ / sizeof(half)); + std::vector tmp; + tmp.reserve(chunk); + size_t idx = 0; + while (idx < count) { + const size_t n = std::min(chunk, count - idx); + tmp.clear(); + for (size_t i = 0; i < n; ++i) { + tmp.push_back(static_cast(values.values[idx + i])); + } + out.write(reinterpret_cast(tmp.data()), static_cast(n * sizeof(half))); + idx += n; + } + } else if (values.dtype == "float32") { + const size_t chunk = std::max(1, metadata_buffer_max_bytes_ / sizeof(float)); + std::vector tmp; + tmp.reserve(chunk); + size_t idx = 0; + while (idx < count) { + const size_t n = std::min(chunk, count - idx); + tmp.clear(); + for (size_t i = 0; i < n; ++i) { + tmp.push_back(static_cast(values.values[idx + i])); + } + out.write(reinterpret_cast(tmp.data()), static_cast(n * sizeof(float))); + idx += n; + } + } else if (values.dtype == "float64") { + const size_t chunk = std::max(1, metadata_buffer_max_bytes_ / sizeof(double)); + std::vector tmp; + tmp.reserve(chunk); + size_t idx = 0; + while (idx < count) { + const size_t n = std::min(chunk, count - idx); + tmp.clear(); + for (size_t i = 0; i < n; ++i) { + tmp.push_back(values.values[idx + i]); + } + out.write(reinterpret_cast(tmp.data()), static_cast(n * sizeof(double))); + idx += n; + } + } else { + throw std::runtime_error("Unsupported metadata dtype: " + values.dtype); + } + out.close(); + }; + + if (metadata_mode_ == MetadataMode::OnDisk) { + for (const auto &meta : metadata_files_) { + const std::string dest = directory + SEPARATOR + meta.relative_path; + const trx::fs::path dest_path(dest); + if (dest_path.has_parent_path()) { + std::error_code parent_ec; + trx::fs::create_directories(dest_path.parent_path(), parent_ec); + } + std::error_code copy_ec; + trx::fs::copy_file(meta.absolute_path, dest, trx::fs::copy_options::overwrite_existing, copy_ec); + if (copy_ec) { + throw std::runtime_error("Failed to copy metadata file: " + meta.absolute_path + " -> " + dest); + } + } + } else { + if (!dps_.empty()) { + trx::fs::create_directories(directory + SEPARATOR + "dps", ec); + for (const auto &kv : dps_) { + const std::string path = directory + SEPARATOR + "dps" + SEPARATOR + kv.first + "." 
+ kv.second.dtype; + write_field_values(path, kv.second); + } + } + if (!dpv_.empty()) { + trx::fs::create_directories(directory + SEPARATOR + "dpv", ec); + for (const auto &kv : dpv_) { + const std::string path = directory + SEPARATOR + "dpv" + SEPARATOR + kv.first + "." + kv.second.dtype; + write_field_values(path, kv.second); + } + } + if (!groups_.empty()) { + trx::fs::create_directories(directory + SEPARATOR + "groups", ec); + for (const auto &kv : groups_) { + const std::string path = directory + SEPARATOR + "groups" + SEPARATOR + kv.first + ".uint32"; + std::ofstream out(path, std::ios::binary | std::ios::out | std::ios::trunc); + if (!out.is_open()) { + throw std::runtime_error("Failed to open group file: " + path); + } + if (!kv.second.empty()) { + out.write(reinterpret_cast(kv.second.data()), + static_cast(kv.second.size() * sizeof(uint32_t))); + } + out.close(); + } + } + } + + cleanup_tmp(); +} + +inline void TrxStream::finalize_directory(const std::string &directory) { + finalize_directory_impl(directory, true); +} + +inline void TrxStream::finalize_directory_persistent(const std::string &directory) { + finalize_directory_impl(directory, false); +} + template void TrxFile
::add_dpv_from_tsf(const std::string &name,
                                     const std::string &dtype,
                                     const std::string &path) {
   if (name.empty()) {
@@ -2050,6 +2708,14 @@ std::vector<std::array<float, 6>> TrxFile<DT>::build_streamline_aabbs() const
   return aabbs;
 }
 
+template <typename DT>
+const std::vector<std::array<float, 6>> &TrxFile<DT>::get_or_build_streamline_aabbs() const {
+  if (this->aabb_cache_.empty()) {
+    this->build_streamline_aabbs();
+  }
+  return this->aabb_cache_;
+}
+
 template <typename DT>
 std::unique_ptr<TrxFile<DT>> TrxFile<DT>::query_aabb(
     const std::array<float, 3> &min_corner,
@@ -2117,6 +2783,64 @@ void TrxFile<DT>::invalidate_aabb_cache() const {
   this->aabb_cache_.clear();
 }
 
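Reading these additions together: get_or_build_streamline_aabbs lazily fills aabb_cache_ on first use, and query_aabb (context above) materializes the streamlines whose boxes intersect the query. A minimal usage sketch, assuming a loaded TrxFile<float> named trx; the box values are illustrative (they mirror the QueryAabbCounts test further down).

const std::array<float, 3> lo{-0.9f, 0.2f, 0.05f};
const std::array<float, 3> hi{-0.1f, 1.1f, 0.55f};
auto subset = trx.query_aabb(lo, hi);                     // new TrxFile holding intersecting streamlines
const auto &boxes = trx.get_or_build_streamline_aabbs();  // built once, then served from aabb_cache_
subset->close();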
+template <typename DT>
+const MMappedMatrix<DT> *TrxFile<DT>::get_dps(const std::string &name) const {
+  auto it = this->data_per_streamline.find(name);
+  if (it == this->data_per_streamline.end()) {
+    return nullptr;
+  }
+  return it->second.get();
+}
+
+template <typename DT>
+const ArraySequence<DT> *TrxFile<DT>::get_dpv(const std::string &name) const {
+  auto it = this->data_per_vertex.find(name);
+  if (it == this->data_per_vertex.end()) {
+    return nullptr;
+  }
+  return it->second.get();
+}
+
+template <typename DT>
+std::vector<std::array<DT, 3>> TrxFile<DT>::get_streamline(size_t streamline_index) const {
+  if (!this->streamlines || this->streamlines->_offsets.size() == 0) {
+    throw std::runtime_error("TRX streamlines are not available");
+  }
+  const size_t n_streamlines = static_cast<size_t>(this->streamlines->_offsets.size() - 1);
+  if (streamline_index >= n_streamlines) {
+    throw std::out_of_range("Streamline index out of range");
+  }
+
+  const uint64_t start = static_cast<uint64_t>(this->streamlines->_offsets(static_cast<Eigen::Index>(streamline_index), 0));
+  const uint64_t end =
+      static_cast<uint64_t>(this->streamlines->_offsets(static_cast<Eigen::Index>(streamline_index + 1), 0));
+  std::vector<std::array<DT, 3>> points;
+  if (end <= start) {
+    return points;
+  }
+  points.reserve(static_cast<size_t>(end - start));
+  for (uint64_t i = start; i < end; ++i) {
+    points.push_back({this->streamlines->_data(static_cast<Eigen::Index>(i), 0),
+                      this->streamlines->_data(static_cast<Eigen::Index>(i), 1),
+                      this->streamlines->_data(static_cast<Eigen::Index>(i), 2)});
+  }
+  return points;
+}
+
+template <typename DT>
+template <typename Fn>
+void TrxFile<DT>::for_each_streamline(Fn &&fn) const {
+  if (!this->streamlines || this->streamlines->_offsets.size() == 0) {
+    return;
+  }
+  const size_t n_streamlines = static_cast<size_t>(this->streamlines->_offsets.size() - 1);
+  for (size_t i = 0; i < n_streamlines; ++i) {
+    const uint64_t start = static_cast<uint64_t>(this->streamlines->_offsets(static_cast<Eigen::Index>(i), 0));
+    const uint64_t end = static_cast<uint64_t>(this->streamlines->_offsets(static_cast<Eigen::Index>(i + 1), 0));
+    fn(i, start, end - start);
+  }
+}
+
 template <typename DT>
 template <typename T>
 void TrxFile<DT>
::add_dpg_from_vector(const std::string &group, diff --git a/src/trx.cpp b/src/trx.cpp index 5762b34..67f0a58 100644 --- a/src/trx.cpp +++ b/src/trx.cpp @@ -16,12 +16,18 @@ #include #include #include +#include #include #include #include #include #include #include +#if defined(_WIN32) || defined(_WIN64) +#include +#else +#include +#endif #include #include @@ -88,6 +94,45 @@ bool is_path_within(const trx::fs::path &child, const trx::fs::path &parent) { return next == '/' || next == '\\'; } +TrxSaveMode resolve_save_mode(const std::string &filename, TrxSaveMode requested) { + if (requested != TrxSaveMode::Auto) { + return requested; + } + const std::string ext = get_ext(filename); + if (ext == "zip" || ext == "trx") { + return TrxSaveMode::Archive; + } + return TrxSaveMode::Directory; +} + +std::array read_xyz_as_double(const TypedArray &positions, size_t row_index) { + if (positions.cols != 3) { + throw std::runtime_error("Positions must have 3 columns."); + } + if (row_index >= static_cast(positions.rows)) { + throw std::out_of_range("Position row index out of range"); + } + if (positions.dtype == "float16") { + const auto view = positions.as_matrix(); + return {static_cast(view(static_cast(row_index), 0)), + static_cast(view(static_cast(row_index), 1)), + static_cast(view(static_cast(row_index), 2))}; + } + if (positions.dtype == "float32") { + const auto view = positions.as_matrix(); + return {static_cast(view(static_cast(row_index), 0)), + static_cast(view(static_cast(row_index), 1)), + static_cast(view(static_cast(row_index), 2))}; + } + if (positions.dtype == "float64") { + const auto view = positions.as_matrix(); + return {view(static_cast(row_index), 0), + view(static_cast(row_index), 1), + view(static_cast(row_index), 2)}; + } + throw std::runtime_error("Unsupported positions dtype for streamline extraction: " + positions.dtype); +} + TypedArray make_typed_array(const std::string &filename, int rows, int cols, const std::string &dtype) { TypedArray array; array.dtype = dtype; @@ -194,6 +239,40 @@ size_t AnyTrxFile::num_streamlines() const { return 0; } +const TypedArray *AnyTrxFile::get_dps(const std::string &name) const { + auto it = data_per_streamline.find(name); + if (it == data_per_streamline.end()) { + return nullptr; + } + return &it->second; +} + +const TypedArray *AnyTrxFile::get_dpv(const std::string &name) const { + auto it = data_per_vertex.find(name); + if (it == data_per_vertex.end()) { + return nullptr; + } + return &it->second; +} + +std::vector> AnyTrxFile::get_streamline(size_t streamline_index) const { + if (offsets_u64.empty()) { + throw std::runtime_error("TRX offsets are empty."); + } + const size_t n_streamlines = offsets_u64.size() - 1; + if (streamline_index >= n_streamlines) { + throw std::out_of_range("Streamline index out of range"); + } + const uint64_t start = offsets_u64[streamline_index]; + const uint64_t end = offsets_u64[streamline_index + 1]; + std::vector> out; + out.reserve(static_cast(end - start)); + for (uint64_t i = start; i < end; ++i) { + out.push_back(read_xyz_as_double(positions, static_cast(i))); + } + return out; +} + void AnyTrxFile::close() { _cleanup_temporary_directory(); positions = TypedArray(); @@ -268,9 +347,42 @@ AnyTrxFile AnyTrxFile::load_from_directory(const std::string &path) { } std::string header_name = directory + SEPARATOR + "header.json"; - std::ifstream header_file(header_name); + std::ifstream header_file; + for (int attempt = 0; attempt < 5; ++attempt) { + header_file.open(header_name); + if (header_file.is_open()) 
{ + break; + } + std::this_thread::sleep_for(std::chrono::milliseconds(50)); + } if (!header_file.is_open()) { - throw std::runtime_error("Failed to open header.json at: " + header_name); + std::error_code ec; + const bool exists = trx::fs::exists(directory, ec); + const int open_err = errno; + std::string detail = "Failed to open header.json at: " + header_name; + detail += " exists=" + std::string(exists ? "true" : "false"); + detail += " errno=" + std::to_string(open_err) + " msg=" + std::string(std::strerror(open_err)); + if (exists) { + std::vector files; + for (const auto &entry : trx::fs::directory_iterator(directory, ec)) { + if (ec) { + break; + } + files.push_back(entry.path().filename().string()); + } + if (!files.empty()) { + std::sort(files.begin(), files.end()); + detail += " files=["; + for (size_t i = 0; i < files.size(); ++i) { + if (i > 0) { + detail += ","; + } + detail += files[i]; + } + detail += "]"; + } + } + throw std::runtime_error(detail); } std::string jstream((std::istreambuf_iterator(header_file)), std::istreambuf_iterator()); header_file.close(); @@ -419,9 +531,16 @@ AnyTrxFile::_create_from_pointer(json header, } void AnyTrxFile::save(const std::string &filename, zip_uint32_t compression_standard) { + TrxSaveOptions options; + options.compression_standard = compression_standard; + save(filename, options); +} + +void AnyTrxFile::save(const std::string &filename, const TrxSaveOptions &options) { const std::string ext = get_ext(filename); - if (ext.size() > 0 && (ext != "zip" && ext != "trx")) { - throw std::invalid_argument("Unsupported extension." + ext); + const TrxSaveMode save_mode = resolve_save_mode(filename, options.mode); + if (ext.size() > 0 && ext != "zip" && ext != "trx") { + throw std::invalid_argument("Unsupported extension: " + ext); } if (offsets.empty()) { @@ -461,35 +580,44 @@ void AnyTrxFile::save(const std::string &filename, zip_uint32_t compression_stan throw std::runtime_error("TRX file has no backing directory to save from"); } - std::string tmp_dir = make_temp_dir("trx_runtime"); - copy_dir(source_dir, tmp_dir); - - { - const trx::fs::path header_path = trx::fs::path(tmp_dir) / "header.json"; - std::ofstream out_json(header_path); - if (!out_json.is_open()) { - throw std::runtime_error("Failed to write header.json to: " + header_path.string()); - } - out_json << header.dump() << std::endl; - } - - if (ext.size() > 0 && (ext == "zip" || ext == "trx")) { + if (save_mode == TrxSaveMode::Archive) { int errorp; zip_t *zf; if ((zf = zip_open(filename.c_str(), ZIP_CREATE + ZIP_TRUNCATE, &errorp)) == nullptr) { - rm_dir(tmp_dir); throw std::runtime_error("Could not open archive " + filename + ": " + strerror(errorp)); } - zip_from_folder(zf, tmp_dir, tmp_dir, compression_standard); + + const std::string header_payload = header.dump() + "\n"; + zip_source_t *header_source = + zip_source_buffer(zf, header_payload.data(), header_payload.size(), 0 /* do not free */); + if (header_source == nullptr) { + zip_close(zf); + throw std::runtime_error("Failed to create zip source for header.json: " + std::string(zip_strerror(zf))); + } + const zip_int64_t header_idx = zip_file_add(zf, "header.json", header_source, ZIP_FL_ENC_UTF_8 | ZIP_FL_OVERWRITE); + if (header_idx < 0) { + zip_source_free(header_source); + zip_close(zf); + throw std::runtime_error("Failed to add header.json to archive: " + std::string(zip_strerror(zf))); + } + const zip_int32_t compression = static_cast(options.compression_standard); + if (zip_set_file_compression(zf, header_idx, 
compression, 0) < 0) { + zip_close(zf); + throw std::runtime_error("Failed to set compression for header.json: " + std::string(zip_strerror(zf))); + } + + const std::unordered_set skip = {"header.json"}; + zip_from_folder(zf, source_dir, source_dir, options.compression_standard, &skip); if (zip_close(zf) != 0) { - rm_dir(tmp_dir); throw std::runtime_error("Unable to close archive " + filename + ": " + zip_strerror(zf)); } } else { std::error_code ec; if (trx::fs::exists(filename, ec) && trx::fs::is_directory(filename, ec)) { + if (!options.overwrite_existing) { + throw std::runtime_error("Output directory already exists: " + filename); + } if (rm_dir(filename) != 0) { - rm_dir(tmp_dir); throw std::runtime_error("Could not remove existing directory " + filename); } } @@ -498,24 +626,35 @@ void AnyTrxFile::save(const std::string &filename, zip_uint32_t compression_stan std::error_code parent_ec; trx::fs::create_directories(dest_path.parent_path(), parent_ec); if (parent_ec) { - rm_dir(tmp_dir); throw std::runtime_error("Could not create output parent directory: " + dest_path.parent_path().string()); } } - copy_dir(tmp_dir, filename); + std::error_code source_ec; + const trx::fs::path source_path = trx::fs::weakly_canonical(trx::fs::path(source_dir), source_ec); + std::error_code dest_ec; + const trx::fs::path normalized_dest = trx::fs::weakly_canonical(dest_path, dest_ec); + const bool same_directory = !source_ec && !dest_ec && source_path == normalized_dest; + + if (!same_directory) { + copy_dir(source_dir, filename); + } + + const trx::fs::path final_header_path = dest_path / "header.json"; + std::ofstream out_json(final_header_path, std::ios::out | std::ios::trunc); + if (!out_json.is_open()) { + throw std::runtime_error("Failed to write header.json to: " + final_header_path.string()); + } + out_json << header.dump() << std::endl; + out_json.close(); + ec.clear(); if (!trx::fs::exists(filename, ec) || !trx::fs::is_directory(filename, ec)) { - rm_dir(tmp_dir); throw std::runtime_error("Failed to create output directory: " + filename); } - const trx::fs::path header_path = dest_path / "header.json"; - if (!trx::fs::exists(header_path)) { - rm_dir(tmp_dir); - throw std::runtime_error("Missing header.json in output directory: " + header_path.string()); + if (!trx::fs::exists(final_header_path)) { + throw std::runtime_error("Missing header.json in output directory: " + final_header_path.string()); } } - - rm_dir(tmp_dir); } void populate_fps(const string &name, std::map> &files_pointer_size) { @@ -806,8 +945,15 @@ std::string make_temp_dir(const std::string &prefix) { static std::mt19937_64 rng(std::random_device{}()); std::uniform_int_distribution dist; + const uint64_t pid = +#if defined(_WIN32) || defined(_WIN64) + static_cast(_getpid()); +#else + static_cast(getpid()); +#endif for (int attempt = 0; attempt < 100; ++attempt) { - const trx::fs::path candidate = base_path / (prefix + "_" + std::to_string(dist(rng))); + const trx::fs::path candidate = + base_path / (prefix + "_" + std::to_string(pid) + "_" + std::to_string(dist(rng))); ec.clear(); if (trx::fs::create_directory(candidate, ec)) { return candidate.string(); @@ -906,7 +1052,8 @@ std::string extract_zip_to_directory(zip_t *zfolder) { void zip_from_folder(zip_t *zf, const std::string &root, const std::string &directory, - zip_uint32_t compression_standard) { + zip_uint32_t compression_standard, + const std::unordered_set *skip) { std::error_code ec; for (trx::fs::recursive_directory_iterator it(directory, ec), end; it != end; 
it.increment(ec)) { if (ec) { @@ -928,6 +1075,10 @@ void zip_from_folder(zip_t *zf, if (source == nullptr) { throw std::runtime_error(std::string("Error adding file ") + zip_fname + ": " + zip_strerror(zf)); } + if (skip && skip->find(zip_fname) != skip->end()) { + zip_source_free(source); + continue; + } const zip_int64_t file_idx = zip_file_add(zf, zip_fname.c_str(), source, ZIP_FL_ENC_UTF_8); if (file_idx < 0) { zip_source_free(source); @@ -949,4 +1100,500 @@ std::string rm_root(const std::string &root, const std::string &path) { } return stripped; } + +namespace { +TrxScalarType scalar_type_from_dtype(const std::string &dtype) { + if (dtype == "float16") { + return TrxScalarType::Float16; + } + if (dtype == "float32") { + return TrxScalarType::Float32; + } + if (dtype == "float64") { + return TrxScalarType::Float64; + } + return TrxScalarType::Float32; +} + +std::string typed_array_filename(const std::string &base, const TypedArray &arr) { + if (arr.cols <= 1) { + return base + "." + arr.dtype; + } + return base + "." + std::to_string(arr.cols) + "." + arr.dtype; +} + +void write_typed_array_file(const std::string &path, const TypedArray &arr) { + const auto bytes = arr.to_bytes(); + std::ofstream out(path, std::ios::binary | std::ios::out | std::ios::trunc); + if (!out.is_open()) { + throw std::runtime_error("Failed to open output file: " + path); + } + if (bytes.data && bytes.size > 0) { + out.write(reinterpret_cast(bytes.data), static_cast(bytes.size)); + } + out.flush(); + out.close(); +} +} // namespace + +void AnyTrxFile::for_each_positions_chunk(size_t chunk_bytes, const PositionsChunkCallback &fn) const { + if (positions.empty()) { + throw std::runtime_error("TRX positions are empty."); + } + if (positions.cols != 3) { + throw std::runtime_error("Positions must have 3 columns."); + } + if (!fn) { + return; + } + const size_t elem_size = static_cast(detail::_sizeof_dtype(positions.dtype)); + const size_t bytes_per_point = elem_size * 3; + const size_t total_points = static_cast(positions.rows); + size_t points_per_chunk = 0; + if (chunk_bytes == 0) { + points_per_chunk = total_points; + } else { + points_per_chunk = std::max(1, chunk_bytes / bytes_per_point); + } + const auto bytes = positions.to_bytes(); + const auto *base = bytes.data; + const auto dtype = scalar_type_from_dtype(positions.dtype); + for (size_t offset = 0; offset < total_points; offset += points_per_chunk) { + const size_t count = std::min(points_per_chunk, total_points - offset); + const void *ptr = base + offset * bytes_per_point; + fn(dtype, ptr, offset, count); + } +} + +void AnyTrxFile::for_each_positions_chunk_mutable(size_t chunk_bytes, const PositionsChunkMutableCallback &fn) { + if (positions.empty()) { + throw std::runtime_error("TRX positions are empty."); + } + if (positions.cols != 3) { + throw std::runtime_error("Positions must have 3 columns."); + } + if (!fn) { + return; + } + const size_t elem_size = static_cast(detail::_sizeof_dtype(positions.dtype)); + const size_t bytes_per_point = elem_size * 3; + const size_t total_points = static_cast(positions.rows); + size_t points_per_chunk = 0; + if (chunk_bytes == 0) { + points_per_chunk = total_points; + } else { + points_per_chunk = std::max(1, chunk_bytes / bytes_per_point); + } + const auto bytes = positions.to_bytes_mutable(); + auto *base = bytes.data; + const auto dtype = scalar_type_from_dtype(positions.dtype); + for (size_t offset = 0; offset < total_points; offset += points_per_chunk) { + const size_t count = std::min(points_per_chunk, 
total_points - offset); + void *ptr = base + offset * bytes_per_point; + fn(dtype, ptr, offset, count); + } +} + +PositionsOutputInfo prepare_positions_output(const AnyTrxFile &input, + const std::string &output_directory, + const PrepareOutputOptions &options) { + if (input.positions.empty() || input.offsets.empty()) { + throw std::runtime_error("Input TRX missing positions/offsets."); + } + if (input.positions.cols != 3) { + throw std::runtime_error("Positions must have 3 columns."); + } + + std::error_code ec; + if (trx::fs::exists(output_directory, ec)) { + if (options.overwrite_existing) { + trx::fs::remove_all(output_directory, ec); + } else { + throw std::runtime_error("Output directory already exists: " + output_directory); + } + } + ec.clear(); + trx::fs::create_directories(output_directory, ec); + if (ec) { + throw std::runtime_error("Failed to create output directory: " + output_directory); + } + + const std::string header_path = output_directory + SEPARATOR + "header.json"; + { + std::ofstream out(header_path, std::ios::out | std::ios::trunc); + if (!out.is_open()) { + throw std::runtime_error("Failed to write header.json to: " + header_path); + } + out << input.header.dump() << std::endl; + } + + write_typed_array_file(output_directory + SEPARATOR + typed_array_filename("offsets", input.offsets), input.offsets); + + if (!input.groups.empty()) { + const std::string groups_dir = output_directory + SEPARATOR + "groups"; + trx::fs::create_directories(groups_dir, ec); + for (const auto &kv : input.groups) { + write_typed_array_file(groups_dir + SEPARATOR + typed_array_filename(kv.first, kv.second), kv.second); + } + } + + if (!input.data_per_streamline.empty()) { + const std::string dps_dir = output_directory + SEPARATOR + "dps"; + trx::fs::create_directories(dps_dir, ec); + for (const auto &kv : input.data_per_streamline) { + write_typed_array_file(dps_dir + SEPARATOR + typed_array_filename(kv.first, kv.second), kv.second); + } + } + + if (!input.data_per_vertex.empty()) { + const std::string dpv_dir = output_directory + SEPARATOR + "dpv"; + trx::fs::create_directories(dpv_dir, ec); + for (const auto &kv : input.data_per_vertex) { + write_typed_array_file(dpv_dir + SEPARATOR + typed_array_filename(kv.first, kv.second), kv.second); + } + } + + if (!input.data_per_group.empty()) { + const std::string dpg_dir = output_directory + SEPARATOR + "dpg"; + trx::fs::create_directories(dpg_dir, ec); + for (const auto &group_kv : input.data_per_group) { + const std::string group_dir = dpg_dir + SEPARATOR + group_kv.first; + trx::fs::create_directories(group_dir, ec); + for (const auto &kv : group_kv.second) { + write_typed_array_file(group_dir + SEPARATOR + typed_array_filename(kv.first, kv.second), kv.second); + } + } + } + + PositionsOutputInfo info; + info.directory = output_directory; + info.dtype = input.positions.dtype; + info.points = static_cast(input.positions.rows); + info.positions_path = output_directory + SEPARATOR + typed_array_filename("positions", input.positions); + return info; +} + +void merge_trx_shards(const MergeTrxShardsOptions &options) { + if (options.shard_directories.empty()) { + throw std::invalid_argument("merge_trx_shards requires at least one shard directory"); + } + + auto read_header = [](const std::string &dir) { + const std::string path = dir + SEPARATOR + "header.json"; + std::ifstream in(path); + if (!in.is_open()) { + throw std::runtime_error("Failed to open shard header: " + path); + } + std::stringstream ss; + ss << in.rdbuf(); + std::string err; + json 
parsed = json::parse(ss.str(), err); + if (!err.empty()) { + throw std::runtime_error("Failed to parse shard header " + path + ": " + err); + } + return parsed; + }; + + auto find_file_with_prefix = [](const std::string &dir, const std::string &prefix) -> std::string { + std::error_code ec; + for (trx::fs::directory_iterator it(dir, ec), end; it != end; it.increment(ec)) { + if (ec) { + break; + } + if (!it->is_regular_file()) { + continue; + } + const std::string name = it->path().filename().string(); + if (name.rfind(prefix, 0) == 0) { + return it->path().string(); + } + } + return ""; + }; + + auto append_binary = [](const std::string &dst, const std::string &src) { + std::ifstream in(src, std::ios::binary); + if (!in.is_open()) { + throw std::runtime_error("Failed to open source for append: " + src); + } + std::ofstream out(dst, std::ios::binary | std::ios::app); + if (!out.is_open()) { + throw std::runtime_error("Failed to open destination for append: " + dst); + } + std::vector buffer(1 << 20); + while (in) { + in.read(buffer.data(), static_cast(buffer.size())); + const auto n = in.gcount(); + if (n > 0) { + out.write(buffer.data(), n); + } + } + }; + + auto append_offsets_with_base = [](const std::string &dst, const std::string &src, uint64_t base_vertices, bool skip_first) { + std::ifstream in(src, std::ios::binary); + if (!in.is_open()) { + throw std::runtime_error("Failed to open source offsets: " + src); + } + std::ofstream out(dst, std::ios::binary | std::ios::app); + if (!out.is_open()) { + throw std::runtime_error("Failed to open destination offsets: " + dst); + } + constexpr size_t kChunkElems = (8 * 1024 * 1024) / sizeof(uint64_t); + std::vector buffer(kChunkElems); + bool first_value_pending = skip_first; + while (in) { + in.read(reinterpret_cast(buffer.data()), + static_cast(buffer.size() * sizeof(uint64_t))); + const std::streamsize bytes = in.gcount(); + if (bytes <= 0) { + break; + } + if (bytes % static_cast(sizeof(uint64_t)) != 0) { + throw std::runtime_error("Offsets file has invalid byte count: " + src); + } + const size_t count = static_cast(bytes) / sizeof(uint64_t); + size_t start_index = 0; + if (first_value_pending) { + if (count == 0) { + continue; + } + start_index = 1; + first_value_pending = false; + } + for (size_t i = start_index; i < count; ++i) { + buffer[i] += base_vertices; + } + if (count > start_index) { + out.write(reinterpret_cast(buffer.data() + start_index), + static_cast((count - start_index) * sizeof(uint64_t))); + } + } + }; + + auto append_group_indices_with_base = [](const std::string &dst, const std::string &src, uint32_t base_streamlines) { + std::ifstream in(src, std::ios::binary); + if (!in.is_open()) { + throw std::runtime_error("Failed to open source group file: " + src); + } + std::ofstream out(dst, std::ios::binary | std::ios::app); + if (!out.is_open()) { + throw std::runtime_error("Failed to open destination group file: " + dst); + } + constexpr size_t kChunkElems = (8 * 1024 * 1024) / sizeof(uint32_t); + std::vector buffer(kChunkElems); + while (in) { + in.read(reinterpret_cast(buffer.data()), + static_cast(buffer.size() * sizeof(uint32_t))); + const std::streamsize bytes = in.gcount(); + if (bytes <= 0) { + break; + } + if (bytes % static_cast(sizeof(uint32_t)) != 0) { + throw std::runtime_error("Group file has invalid byte count: " + src); + } + const size_t count = static_cast(bytes) / sizeof(uint32_t); + for (size_t i = 0; i < count; ++i) { + buffer[i] += base_streamlines; + } + out.write(reinterpret_cast(buffer.data()), 
static_cast(count * sizeof(uint32_t))); + } + }; + + auto list_subdir_files = [](const std::string &dir, const std::string &subdir) { + std::vector files; + std::error_code ec; + const trx::fs::path path = trx::fs::path(dir) / subdir; + if (!trx::fs::exists(path, ec)) { + return files; + } + if (!trx::fs::is_directory(path, ec)) { + throw std::runtime_error("Expected directory for subdir: " + path.string()); + } + for (trx::fs::directory_iterator it(path, ec), end; it != end; it.increment(ec)) { + if (ec) { + throw std::runtime_error("Failed to read directory: " + path.string()); + } + if (!it->is_regular_file()) { + continue; + } + files.push_back(it->path().filename().string()); + } + std::sort(files.begin(), files.end()); + return files; + }; + + auto ensure_schema_match = [&](const std::string &subdir, const std::vector &schema_files, const std::string &shard) { + const auto shard_files = list_subdir_files(shard, subdir); + if (shard_files != schema_files) { + throw std::runtime_error("Shard schema mismatch for subdir '" + subdir + "': " + shard); + } + }; + + std::error_code ec; + for (const auto &dir : options.shard_directories) { + if (!trx::fs::exists(dir, ec) || !trx::fs::is_directory(dir, ec)) { + throw std::runtime_error("Shard directory does not exist: " + dir); + } + } + + const std::string output_dir = options.output_directory ? options.output_path : make_temp_dir("trx_merge"); + if (trx::fs::exists(output_dir, ec)) { + if (!options.overwrite_existing) { + throw std::runtime_error("Output already exists: " + output_dir); + } + trx::fs::remove_all(output_dir, ec); + } + trx::fs::create_directories(output_dir, ec); + if (ec) { + throw std::runtime_error("Failed to create output directory: " + output_dir); + } + + for (const auto &dir : options.shard_directories) { + if (trx::fs::exists(dir + SEPARATOR + "dpg", ec)) { + throw std::runtime_error("merge_trx_shards currently does not support dpg/ merges"); + } + } + + json merged_header = read_header(options.shard_directories.front()); + const std::string first_positions = find_file_with_prefix(options.shard_directories.front(), "positions."); + const std::string first_offsets = find_file_with_prefix(options.shard_directories.front(), "offsets."); + if (first_positions.empty() || first_offsets.empty()) { + throw std::runtime_error("Shard missing positions/offsets: " + options.shard_directories.front()); + } + if (get_ext(first_offsets) != "uint64") { + throw std::runtime_error("merge_trx_shards currently requires offsets.uint64"); + } + + const std::string positions_filename = trx::fs::path(first_positions).filename().string(); + const std::string offsets_filename = trx::fs::path(first_offsets).filename().string(); + const std::string positions_out = output_dir + SEPARATOR + positions_filename; + const std::string offsets_out = output_dir + SEPARATOR + offsets_filename; + { + std::ofstream clear_positions(positions_out, std::ios::binary | std::ios::out | std::ios::trunc); + if (!clear_positions.is_open()) { + throw std::runtime_error("Failed to create output positions file: " + positions_out); + } + } + { + std::ofstream clear_offsets(offsets_out, std::ios::binary | std::ios::out | std::ios::trunc); + if (!clear_offsets.is_open()) { + throw std::runtime_error("Failed to create output offsets file: " + offsets_out); + } + } + + const auto dps_schema = list_subdir_files(options.shard_directories.front(), "dps"); + const auto dpv_schema = list_subdir_files(options.shard_directories.front(), "dpv"); + const auto groups_schema = 
list_subdir_files(options.shard_directories.front(), "groups"); + if (!dps_schema.empty()) { + trx::fs::create_directories(output_dir + SEPARATOR + "dps", ec); + } + if (!dpv_schema.empty()) { + trx::fs::create_directories(output_dir + SEPARATOR + "dpv", ec); + } + if (!groups_schema.empty()) { + trx::fs::create_directories(output_dir + SEPARATOR + "groups", ec); + } + for (const auto &name : dps_schema) { + std::ofstream clear_file(output_dir + SEPARATOR + "dps" + SEPARATOR + name, std::ios::binary | std::ios::out | std::ios::trunc); + if (!clear_file.is_open()) { + throw std::runtime_error("Failed to create merged dps file: " + name); + } + } + for (const auto &name : dpv_schema) { + std::ofstream clear_file(output_dir + SEPARATOR + "dpv" + SEPARATOR + name, std::ios::binary | std::ios::out | std::ios::trunc); + if (!clear_file.is_open()) { + throw std::runtime_error("Failed to create merged dpv file: " + name); + } + } + for (const auto &name : groups_schema) { + std::ofstream clear_file(output_dir + SEPARATOR + "groups" + SEPARATOR + name, + std::ios::binary | std::ios::out | std::ios::trunc); + if (!clear_file.is_open()) { + throw std::runtime_error("Failed to create merged group file: " + name); + } + } + + uint64_t total_vertices = 0; + uint64_t total_streamlines = 0; + for (size_t i = 0; i < options.shard_directories.size(); ++i) { + const std::string &shard_dir = options.shard_directories[i]; + ensure_schema_match("dps", dps_schema, shard_dir); + ensure_schema_match("dpv", dpv_schema, shard_dir); + ensure_schema_match("groups", groups_schema, shard_dir); + + const json shard_header = read_header(shard_dir); + const uint64_t shard_vertices = static_cast(shard_header["NB_VERTICES"].int_value()); + const uint64_t shard_streamlines = static_cast(shard_header["NB_STREAMLINES"].int_value()); + + const std::string shard_positions = find_file_with_prefix(shard_dir, "positions."); + const std::string shard_offsets = find_file_with_prefix(shard_dir, "offsets."); + if (shard_positions.empty() || shard_offsets.empty()) { + throw std::runtime_error("Shard missing positions/offsets: " + shard_dir); + } + if (trx::fs::path(shard_positions).filename().string() != positions_filename) { + throw std::runtime_error("Shard positions dtype mismatch: " + shard_dir); + } + if (trx::fs::path(shard_offsets).filename().string() != offsets_filename) { + throw std::runtime_error("Shard offsets dtype mismatch: " + shard_dir); + } + + append_binary(positions_out, shard_positions); + append_offsets_with_base(offsets_out, shard_offsets, total_vertices, i != 0); + + for (const auto &name : dps_schema) { + append_binary(output_dir + SEPARATOR + "dps" + SEPARATOR + name, shard_dir + SEPARATOR + "dps" + SEPARATOR + name); + } + for (const auto &name : dpv_schema) { + append_binary(output_dir + SEPARATOR + "dpv" + SEPARATOR + name, shard_dir + SEPARATOR + "dpv" + SEPARATOR + name); + } + for (const auto &name : groups_schema) { + if (total_streamlines > static_cast(std::numeric_limits::max())) { + throw std::runtime_error("Group index offset exceeds uint32 range during merge"); + } + append_group_indices_with_base( + output_dir + SEPARATOR + "groups" + SEPARATOR + name, + shard_dir + SEPARATOR + "groups" + SEPARATOR + name, + static_cast(total_streamlines)); + } + + total_vertices += shard_vertices; + total_streamlines += shard_streamlines; + } + + merged_header = _json_set(merged_header, "NB_VERTICES", static_cast(total_vertices)); + merged_header = _json_set(merged_header, "NB_STREAMLINES", 
static_cast(total_streamlines)); + { + const std::string merged_header_path = output_dir + SEPARATOR + "header.json"; + std::ofstream out(merged_header_path, std::ios::out | std::ios::trunc); + if (!out.is_open()) { + throw std::runtime_error("Failed to write merged header: " + merged_header_path); + } + out << merged_header.dump() << std::endl; + } + + if (options.output_directory) { + return; + } + + const trx::fs::path archive_path(options.output_path); + if (archive_path.has_parent_path()) { + std::error_code parent_ec; + trx::fs::create_directories(archive_path.parent_path(), parent_ec); + if (parent_ec) { + throw std::runtime_error("Could not create archive parent directory: " + archive_path.parent_path().string()); + } + } + + int errorp = 0; + zip_t *zf = zip_open(options.output_path.c_str(), ZIP_CREATE + ZIP_TRUNCATE, &errorp); + if (zf == nullptr) { + throw std::runtime_error("Could not open archive " + options.output_path + ": " + strerror(errorp)); + } + zip_from_folder(zf, output_dir, output_dir, options.compression_standard, nullptr); + if (zip_close(zf) != 0) { + throw std::runtime_error("Unable to close archive " + options.output_path + ": " + zip_strerror(zf)); + } + rm_dir(output_dir); +} }; // namespace trx \ No newline at end of file diff --git a/tests/test_trx_anytrxfile.cpp b/tests/test_trx_anytrxfile.cpp index 6d66463..9803d53 100644 --- a/tests/test_trx_anytrxfile.cpp +++ b/tests/test_trx_anytrxfile.cpp @@ -232,6 +232,136 @@ void expect_basic_consistency(const AnyTrxFile &trx) { EXPECT_NE(bytes.data, nullptr); EXPECT_GT(bytes.size, 0U); } + +fs::path write_test_shard_with_dtype(const fs::path &root, + const std::string &name, + const std::vector> &points, + const std::vector &offsets, + const std::string &positions_dtype) { + const fs::path shard_dir = root / name; + std::error_code ec; + fs::create_directories(shard_dir, ec); + + json::object header_obj; + header_obj["DIMENSIONS"] = json::array{1, 1, 1}; + header_obj["NB_STREAMLINES"] = static_cast(offsets.size() - 1); + header_obj["NB_VERTICES"] = static_cast(points.size()); + header_obj["VOXEL_TO_RASMM"] = json::array{ + json::array{1.0, 0.0, 0.0, 0.0}, + json::array{0.0, 1.0, 0.0, 0.0}, + json::array{0.0, 0.0, 1.0, 0.0}, + json::array{0.0, 0.0, 0.0, 1.0}, + }; + write_header_file(shard_dir, json(header_obj)); + + const std::string pos_file = (shard_dir / ("positions.3." 
+ positions_dtype)).string(); + if (positions_dtype == "float16") { + Eigen::Matrix m( + static_cast(points.size()), 3); + for (size_t i = 0; i < points.size(); ++i) { + m(static_cast(i), 0) = static_cast(points[i][0]); + m(static_cast(i), 1) = static_cast(points[i][1]); + m(static_cast(i), 2) = static_cast(points[i][2]); + } + trx::write_binary(pos_file, m); + } else if (positions_dtype == "float64") { + Eigen::Matrix m( + static_cast(points.size()), 3); + for (size_t i = 0; i < points.size(); ++i) { + m(static_cast(i), 0) = static_cast(points[i][0]); + m(static_cast(i), 1) = static_cast(points[i][1]); + m(static_cast(i), 2) = static_cast(points[i][2]); + } + trx::write_binary(pos_file, m); + } else { + Eigen::Matrix m( + static_cast(points.size()), 3); + for (size_t i = 0; i < points.size(); ++i) { + m(static_cast(i), 0) = points[i][0]; + m(static_cast(i), 1) = points[i][1]; + m(static_cast(i), 2) = points[i][2]; + } + trx::write_binary(pos_file, m); + } + + Eigen::Matrix offsets_mat( + static_cast(offsets.size()), 1); + for (size_t i = 0; i < offsets.size(); ++i) { + offsets_mat(static_cast(i), 0) = offsets[i]; + } + trx::write_binary((shard_dir / "offsets.uint64").string(), offsets_mat); + + return shard_dir; +} + +fs::path write_test_shard(const fs::path &root, + const std::string &name, + const std::vector> &points, + const std::vector &offsets, + const std::vector &dps_values, + const std::vector &dpv_values, + const std::vector &group_indices) { + const fs::path shard_dir = root / name; + std::error_code ec; + fs::create_directories(shard_dir, ec); + if (ec) { + throw std::runtime_error("Failed to create shard directory: " + shard_dir.string()); + } + + json::object header_obj; + header_obj["DIMENSIONS"] = json::array{1, 1, 1}; + header_obj["NB_STREAMLINES"] = static_cast(offsets.size() - 1); + header_obj["NB_VERTICES"] = static_cast(points.size()); + header_obj["VOXEL_TO_RASMM"] = json::array{ + json::array{1.0, 0.0, 0.0, 0.0}, + json::array{0.0, 1.0, 0.0, 0.0}, + json::array{0.0, 0.0, 1.0, 0.0}, + json::array{0.0, 0.0, 0.0, 1.0}, + }; + write_header_file(shard_dir, json(header_obj)); + + Eigen::Matrix positions( + static_cast(points.size()), 3); + for (size_t i = 0; i < points.size(); ++i) { + positions(static_cast(i), 0) = points[i][0]; + positions(static_cast(i), 1) = points[i][1]; + positions(static_cast(i), 2) = points[i][2]; + } + trx::write_binary((shard_dir / "positions.3.float32").string(), positions); + + Eigen::Matrix offsets_mat( + static_cast(offsets.size()), 1); + for (size_t i = 0; i < offsets.size(); ++i) { + offsets_mat(static_cast(i), 0) = offsets[i]; + } + trx::write_binary((shard_dir / "offsets.uint64").string(), offsets_mat); + + fs::create_directories(shard_dir / "dps", ec); + Eigen::Matrix dps_mat( + static_cast(dps_values.size()), 1); + for (size_t i = 0; i < dps_values.size(); ++i) { + dps_mat(static_cast(i), 0) = dps_values[i]; + } + trx::write_binary((shard_dir / "dps" / "weight.float32").string(), dps_mat); + + fs::create_directories(shard_dir / "dpv", ec); + Eigen::Matrix dpv_mat( + static_cast(dpv_values.size()), 1); + for (size_t i = 0; i < dpv_values.size(); ++i) { + dpv_mat(static_cast(i), 0) = dpv_values[i]; + } + trx::write_binary((shard_dir / "dpv" / "signal.float32").string(), dpv_mat); + + fs::create_directories(shard_dir / "groups", ec); + Eigen::Matrix groups_mat( + static_cast(group_indices.size()), 1); + for (size_t i = 0; i < group_indices.size(); ++i) { + groups_mat(static_cast(i), 0) = group_indices[i]; + } + 
trx::write_binary((shard_dir / "groups" / "Bundle.uint32").string(), groups_mat); + + return shard_dir; +} } // namespace TEST(AnyTrxFile, LoadZipAndValidate) { @@ -639,3 +769,394 @@ TEST(AnyTrxFile, SaveRejectsMissingBackingDirectory) { std::error_code ec; fs::remove_all(temp_dir, ec); } + +TEST(AnyTrxFile, MergeTrxShardsDirectoryOutput) { + const fs::path temp_root = make_temp_test_dir("trx_merge_shards"); + const fs::path shard1 = write_test_shard(temp_root, + "shard1", + {{0.0F, 0.0F, 0.0F}, {1.0F, 0.0F, 0.0F}}, + {0, 2}, + {10.0F}, + {0.1F, 0.2F}, + {0}); + const fs::path shard2 = write_test_shard(temp_root, + "shard2", + {{2.0F, 0.0F, 0.0F}, {3.0F, 0.0F, 0.0F}, {4.0F, 0.0F, 0.0F}}, + {0, 1, 3}, + {20.0F, 30.0F}, + {0.3F, 0.4F, 0.5F}, + {1}); + + const fs::path output_dir = temp_root / "merged"; + MergeTrxShardsOptions options; + options.shard_directories = {shard1.string(), shard2.string()}; + options.output_path = output_dir.string(); + options.output_directory = true; + merge_trx_shards(options); + + auto merged = load_any(output_dir.string()); + EXPECT_EQ(merged.num_streamlines(), 3U); + EXPECT_EQ(merged.num_vertices(), 5U); + ASSERT_EQ(merged.offsets_u64.size(), 4U); + EXPECT_EQ(merged.offsets_u64[0], 0U); + EXPECT_EQ(merged.offsets_u64[1], 2U); + EXPECT_EQ(merged.offsets_u64[2], 3U); + EXPECT_EQ(merged.offsets_u64[3], 5U); + + const auto pos = merged.positions.as_matrix(); + EXPECT_FLOAT_EQ(pos(0, 0), 0.0F); + EXPECT_FLOAT_EQ(pos(4, 0), 4.0F); + + auto dps_it = merged.data_per_streamline.find("weight"); + ASSERT_NE(dps_it, merged.data_per_streamline.end()); + auto dps = dps_it->second.as_matrix(); + EXPECT_EQ(dps.rows(), 3); + EXPECT_FLOAT_EQ(dps(0, 0), 10.0F); + EXPECT_FLOAT_EQ(dps(2, 0), 30.0F); + + auto dpv_it = merged.data_per_vertex.find("signal"); + ASSERT_NE(dpv_it, merged.data_per_vertex.end()); + auto dpv = dpv_it->second.as_matrix(); + EXPECT_EQ(dpv.rows(), 5); + EXPECT_FLOAT_EQ(dpv(0, 0), 0.1F); + EXPECT_FLOAT_EQ(dpv(4, 0), 0.5F); + + auto group_it = merged.groups.find("Bundle"); + ASSERT_NE(group_it, merged.groups.end()); + auto grp = group_it->second.as_matrix(); + ASSERT_EQ(grp.rows(), 2); + EXPECT_EQ(grp(0, 0), 0U); + EXPECT_EQ(grp(1, 0), 2U); + merged.close(); + + std::error_code ec; + fs::remove_all(temp_root, ec); +} + +TEST(AnyTrxFile, MergeTrxShardsSchemaMismatchThrows) { + const fs::path temp_root = make_temp_test_dir("trx_merge_schema"); + const fs::path shard1 = write_test_shard(temp_root, + "shard1", + {{0.0F, 0.0F, 0.0F}, {1.0F, 0.0F, 0.0F}}, + {0, 2}, + {10.0F}, + {0.1F, 0.2F}, + {0}); + const fs::path shard2 = write_test_shard(temp_root, + "shard2", + {{2.0F, 0.0F, 0.0F}}, + {0, 1}, + {20.0F}, + {0.3F}, + {0}); + + std::error_code ec; + fs::remove(shard2 / "dpv" / "signal.float32", ec); + + const fs::path output_dir = temp_root / "merged"; + MergeTrxShardsOptions options; + options.shard_directories = {shard1.string(), shard2.string()}; + options.output_path = output_dir.string(); + options.output_directory = true; + EXPECT_THROW(merge_trx_shards(options), std::runtime_error); + + fs::remove_all(temp_root, ec); +} + +TEST(AnyTrxFile, PreparePositionsOutputCopiesMetadataAndOffsets) { + const fs::path temp_root = make_temp_test_dir("trx_prepare_positions"); + const fs::path shard1 = write_test_shard(temp_root, + "shard1", + {{0.0F, 0.0F, 0.0F}, {1.0F, 0.0F, 0.0F}}, + {0, 2}, + {10.0F}, + {0.1F, 0.2F}, + {0}); + auto input = load_any(shard1.string()); + + const fs::path output_dir = temp_root / "prepared"; + PrepareOutputOptions options; + 
options.overwrite_existing = true; + const auto info = prepare_positions_output(input, output_dir.string(), options); + + EXPECT_EQ(info.directory, output_dir.string()); + EXPECT_EQ(info.dtype, "float32"); + EXPECT_EQ(info.points, 2U); + EXPECT_EQ(fs::path(info.positions_path).filename().string(), "positions.3.float32"); + EXPECT_TRUE(fs::exists(output_dir / "header.json")); + EXPECT_TRUE(fs::exists(output_dir / "offsets.uint64")); + EXPECT_TRUE(fs::exists(output_dir / "dps" / "weight.float32")); + EXPECT_TRUE(fs::exists(output_dir / "dpv" / "signal.float32")); + EXPECT_TRUE(fs::exists(output_dir / "groups" / "Bundle.uint32")); + EXPECT_FALSE(fs::exists(info.positions_path)); + + input.close(); + std::error_code ec; + fs::remove_all(temp_root, ec); +} + +TEST(AnyTrxFile, PositionsChunkIterationAndMutation) { + const fs::path temp_root = make_temp_test_dir("trx_positions_chunk"); + const fs::path shard1 = write_test_shard(temp_root, + "shard1", + {{0.0F, 0.0F, 0.0F}, {1.0F, 2.0F, 3.0F}}, + {0, 2}, + {10.0F}, + {0.1F, 0.2F}, + {0}); + auto trx = load_any(shard1.string()); + + size_t total_points = 0; + size_t callbacks = 0; + trx.for_each_positions_chunk(12, [&](TrxScalarType dtype, const void *data, size_t point_offset, size_t point_count) { + EXPECT_EQ(dtype, TrxScalarType::Float32); + EXPECT_NE(data, nullptr); + EXPECT_EQ(point_count, 1U); + EXPECT_LT(point_offset, 2U); + total_points += point_count; + callbacks += 1; + }); + EXPECT_EQ(total_points, 2U); + EXPECT_EQ(callbacks, 2U); + + trx.for_each_positions_chunk_mutable(12, [&](TrxScalarType dtype, void *data, size_t, size_t point_count) { + EXPECT_EQ(dtype, TrxScalarType::Float32); + auto *vals = reinterpret_cast(data); + for (size_t i = 0; i < point_count * 3; ++i) { + vals[i] += 1.0F; + } + }); + + const auto positions = trx.positions.as_matrix(); + EXPECT_FLOAT_EQ(positions(0, 0), 1.0F); + EXPECT_FLOAT_EQ(positions(1, 0), 2.0F); + EXPECT_FLOAT_EQ(positions(1, 1), 3.0F); + EXPECT_FLOAT_EQ(positions(1, 2), 4.0F); + trx.close(); + + std::error_code ec; + fs::remove_all(temp_root, ec); +} + +TEST(AnyTrxFile, MergeTrxShardsArchiveOutput) { + const fs::path temp_root = make_temp_test_dir("trx_merge_shards_archive"); + const fs::path shard1 = write_test_shard(temp_root, + "shard1", + {{0.0F, 0.0F, 0.0F}, {1.0F, 0.0F, 0.0F}}, + {0, 2}, + {10.0F}, + {0.1F, 0.2F}, + {0}); + const fs::path shard2 = write_test_shard(temp_root, + "shard2", + {{2.0F, 0.0F, 0.0F}, {3.0F, 0.0F, 0.0F}}, + {0, 2}, + {20.0F}, + {0.3F, 0.4F}, + {0}); + + const fs::path output_archive = temp_root / "merged.trx"; + MergeTrxShardsOptions options; + options.shard_directories = {shard1.string(), shard2.string()}; + options.output_path = output_archive.string(); + options.output_directory = false; + merge_trx_shards(options); + + ASSERT_TRUE(fs::exists(output_archive)); + auto merged = load_any(output_archive.string()); + EXPECT_EQ(merged.num_streamlines(), 2U); + EXPECT_EQ(merged.num_vertices(), 4U); + merged.close(); + + std::error_code ec; + fs::remove_all(temp_root, ec); +} + +TEST(AnyTrxFile, MergeTrxShardsRejectsDpg) { + const fs::path temp_root = make_temp_test_dir("trx_merge_shards_dpg"); + const fs::path shard1 = write_test_shard(temp_root, + "shard1", + {{0.0F, 0.0F, 0.0F}}, + {0, 1}, + {10.0F}, + {0.1F}, + {0}); + const fs::path shard2 = write_test_shard(temp_root, + "shard2", + {{1.0F, 0.0F, 0.0F}}, + {0, 1}, + {20.0F}, + {0.2F}, + {0}); + std::error_code ec; + fs::create_directories(shard2 / "dpg" / "Bundle", ec); + std::ofstream out((shard2 / "dpg" / "Bundle" / 
"mean.float32").string(), std::ios::binary | std::ios::trunc); + float one = 1.0F; + out.write(reinterpret_cast(&one), sizeof(float)); + out.close(); + + MergeTrxShardsOptions options; + options.shard_directories = {shard1.string(), shard2.string()}; + options.output_path = (temp_root / "merged").string(); + options.output_directory = true; + EXPECT_THROW(merge_trx_shards(options), std::runtime_error); + + fs::remove_all(temp_root, ec); +} + +TEST(AnyTrxFile, PreparePositionsOutputOverwriteFalseThrows) { + const fs::path temp_root = make_temp_test_dir("trx_prepare_positions_overwrite"); + const fs::path shard1 = write_test_shard(temp_root, + "shard1", + {{0.0F, 0.0F, 0.0F}}, + {0, 1}, + {10.0F}, + {0.1F}, + {0}); + auto input = load_any(shard1.string()); + const fs::path output_dir = temp_root / "prepared"; + std::error_code ec; + fs::create_directories(output_dir, ec); + ASSERT_TRUE(fs::exists(output_dir)); + + PrepareOutputOptions options; + options.overwrite_existing = false; + EXPECT_THROW(prepare_positions_output(input, output_dir.string(), options), std::runtime_error); + + input.close(); + fs::remove_all(temp_root, ec); +} + +TEST(AnyTrxFile, SaveRespectsExplicitMode) { + const fs::path temp_root = make_temp_test_dir("trx_any_save_modes"); + const fs::path shard1 = write_test_shard(temp_root, + "shard1", + {{0.0F, 0.0F, 0.0F}, {1.0F, 0.0F, 0.0F}}, + {0, 2}, + {10.0F}, + {0.1F, 0.2F}, + {0}); + auto trx = load_any(shard1.string()); + + const fs::path dir_out = temp_root / "save_dir_mode"; + TrxSaveOptions dir_opts; + dir_opts.mode = TrxSaveMode::Directory; + trx.save(dir_out.string(), dir_opts); + EXPECT_TRUE(fs::is_directory(dir_out)); + EXPECT_TRUE(fs::exists(dir_out / "header.json")); + + const fs::path archive_out = temp_root / "save_archive_mode.trx"; + TrxSaveOptions archive_opts; + archive_opts.mode = TrxSaveMode::Archive; + archive_opts.compression_standard = ZIP_CM_STORE; + trx.save(archive_out.string(), archive_opts); + EXPECT_TRUE(fs::is_regular_file(archive_out)); + + auto dir_loaded = load_any(dir_out.string()); + auto arc_loaded = load_any(archive_out.string()); + EXPECT_EQ(dir_loaded.num_streamlines(), trx.num_streamlines()); + EXPECT_EQ(arc_loaded.num_vertices(), trx.num_vertices()); + dir_loaded.close(); + arc_loaded.close(); + trx.close(); + + std::error_code ec; + fs::remove_all(temp_root, ec); +} + +TEST(AnyTrxFile, GetDpsAndDpvReturnCorrectArrays) { + const fs::path temp_root = make_temp_test_dir("trx_get_dps_dpv"); + const fs::path shard = write_test_shard(temp_root, + "s1", + {{0.0F, 0.0F, 0.0F}, {1.0F, 0.0F, 0.0F}}, + {0, 2}, + {42.0F}, + {0.1F, 0.2F}, + {0}); + + auto trx = load_any(shard.string()); + + const TypedArray *dps = trx.get_dps("weight"); + ASSERT_NE(dps, nullptr); + EXPECT_EQ(dps->rows, 1); + EXPECT_EQ(dps->cols, 1); + + const TypedArray *dps_missing = trx.get_dps("no_such_field"); + EXPECT_EQ(dps_missing, nullptr); + + const TypedArray *dpv = trx.get_dpv("signal"); + ASSERT_NE(dpv, nullptr); + EXPECT_EQ(dpv->rows, 2); + + const TypedArray *dpv_missing = trx.get_dpv("no_such_field"); + EXPECT_EQ(dpv_missing, nullptr); + + trx.close(); + std::error_code ec; + fs::remove_all(temp_root, ec); +} + +TEST(AnyTrxFile, GetStreamlineFloat32) { + const fs::path temp_root = make_temp_test_dir("trx_get_sl_f32"); + const fs::path shard = write_test_shard_with_dtype(temp_root, + "s1", + {{1.0F, 2.0F, 3.0F}, {4.0F, 5.0F, 6.0F}}, + {0, 2}, + "float32"); + + auto trx = load_any(shard.string()); + ASSERT_EQ(trx.num_streamlines(), 1u); + + auto sl = 
trx.get_streamline(0); + ASSERT_EQ(sl.size(), 2u); + EXPECT_NEAR(sl[0][0], 1.0, 1e-5); + EXPECT_NEAR(sl[1][2], 6.0, 1e-5); + + EXPECT_THROW(trx.get_streamline(1), std::out_of_range); + + trx.close(); + std::error_code ec; + fs::remove_all(temp_root, ec); +} + +TEST(AnyTrxFile, GetStreamlineFloat16) { + const fs::path temp_root = make_temp_test_dir("trx_get_sl_f16"); + const fs::path shard = write_test_shard_with_dtype(temp_root, + "s1", + {{1.0F, 2.0F, 3.0F}, {4.0F, 5.0F, 6.0F}}, + {0, 2}, + "float16"); + + auto trx = load_any(shard.string()); + ASSERT_EQ(trx.num_streamlines(), 1u); + + auto sl = trx.get_streamline(0); + ASSERT_EQ(sl.size(), 2u); + EXPECT_NEAR(sl[0][0], 1.0, 1e-2); + EXPECT_NEAR(sl[1][2], 6.0, 1e-2); + + trx.close(); + std::error_code ec; + fs::remove_all(temp_root, ec); +} + +TEST(AnyTrxFile, GetStreamlineFloat64) { + const fs::path temp_root = make_temp_test_dir("trx_get_sl_f64"); + const fs::path shard = write_test_shard_with_dtype(temp_root, + "s1", + {{1.0F, 2.0F, 3.0F}, {4.0F, 5.0F, 6.0F}}, + {0, 2}, + "float64"); + + auto trx = load_any(shard.string()); + ASSERT_EQ(trx.num_streamlines(), 1u); + + auto sl = trx.get_streamline(0); + ASSERT_EQ(sl.size(), 2u); + EXPECT_NEAR(sl[0][0], 1.0, 1e-9); + EXPECT_NEAR(sl[1][2], 6.0, 1e-9); + + trx.close(); + std::error_code ec; + fs::remove_all(temp_root, ec); +} diff --git a/tests/test_trx_mmap.cpp b/tests/test_trx_mmap.cpp index d4fb5ef..7eb443b 100644 --- a/tests/test_trx_mmap.cpp +++ b/tests/test_trx_mmap.cpp @@ -181,7 +181,7 @@ TestTrxFixture create_fixture() { if (zf == nullptr) { throw std::runtime_error("Failed to create trx zip file"); } - trx::zip_from_folder(zf, trx_dir.string(), trx_dir.string(), ZIP_CM_STORE); + trx::zip_from_folder(zf, trx_dir.string(), trx_dir.string(), ZIP_CM_STORE, nullptr); if (zip_close(zf) != 0) { throw std::runtime_error("Failed to close trx zip file"); } diff --git a/tests/test_trx_trxfile.cpp b/tests/test_trx_trxfile.cpp index 9bb96d9..2336535 100644 --- a/tests/test_trx_trxfile.cpp +++ b/tests/test_trx_trxfile.cpp @@ -292,6 +292,87 @@ TEST(TrxFileTpp, TrxStreamFinalize) { fs::remove_all(tmp_dir, ec); } +TEST(TrxFileTpp, TrxStreamOnDiskMetadataAllDtypes) { + auto tmp_dir = make_temp_test_dir("trx_ondisk_dtypes"); + const fs::path out_path = tmp_dir / "ondisk.trx"; + + TrxStream proto; + proto.set_metadata_mode(TrxStream::MetadataMode::OnDisk); + + std::vector sl1 = {0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f}; + std::vector sl2 = {2.0f, 0.0f, 0.0f}; + proto.push_streamline(sl1); + proto.push_streamline(sl2); + + proto.push_dps_from_vector("w_f16", "float16", std::vector{0.5f, 1.5f}); + proto.push_dps_from_vector("w_f32", "float32", std::vector{0.5f, 1.5f}); + proto.push_dps_from_vector("w_f64", "float64", std::vector{0.5, 1.5}); + + proto.push_dpv_from_vector("s_f16", "float16", std::vector{1.0f, 2.0f, 3.0f}); + proto.push_dpv_from_vector("s_f32", "float32", std::vector{1.0f, 2.0f, 3.0f}); + proto.push_dpv_from_vector("s_f64", "float64", std::vector{1.0, 2.0, 3.0}); + + proto.finalize(out_path.string(), ZIP_CM_STORE); + + auto trx = load_any(out_path.string()); + EXPECT_EQ(trx.num_streamlines(), 2u); + EXPECT_EQ(trx.num_vertices(), 3u); + + for (const auto &key : {"w_f16", "w_f32", "w_f64"}) { + auto it = trx.data_per_streamline.find(key); + ASSERT_NE(it, trx.data_per_streamline.end()) << "missing dps key: " << key; + EXPECT_EQ(it->second.rows, 2); + } + for (const auto &key : {"s_f16", "s_f32", "s_f64"}) { + auto it = trx.data_per_vertex.find(key); + ASSERT_NE(it, trx.data_per_vertex.end()) << 
"missing dpv key: " << key; + EXPECT_EQ(it->second.rows, 3); + } + + trx.close(); + std::error_code ec; + fs::remove_all(tmp_dir, ec); +} + +TEST(TrxFileTpp, QueryAabbCounts) { + constexpr int kStreamlineCount = 1000; + constexpr int kInsideCount = 250; + constexpr int kPointsPerStreamline = 5; + + const int nb_vertices = kStreamlineCount * kPointsPerStreamline; + trx::TrxFile trx(nb_vertices, kStreamlineCount); + + trx.streamlines->_offsets(0, 0) = 0; + for (int i = 0; i < kStreamlineCount; ++i) { + trx.streamlines->_lengths(i) = kPointsPerStreamline; + trx.streamlines->_offsets(i + 1, 0) = (i + 1) * kPointsPerStreamline; + } + + int cursor = 0; + for (int i = 0; i < kStreamlineCount; ++i) { + const bool inside = i < kInsideCount; + for (int p = 0; p < kPointsPerStreamline; ++p, ++cursor) { + if (inside) { + trx.streamlines->_data(cursor, 0) = -0.8f + 0.05f * static_cast(p); + trx.streamlines->_data(cursor, 1) = 0.3f + 0.1f * static_cast(p); + trx.streamlines->_data(cursor, 2) = 0.1f + 0.05f * static_cast(p); + } else { + trx.streamlines->_data(cursor, 0) = 0.0f; + trx.streamlines->_data(cursor, 1) = 0.0f; + trx.streamlines->_data(cursor, 2) = -1000.0f - static_cast(i); + } + } + } + + const std::array min_corner{ -0.9f, 0.2f, 0.05f }; + const std::array max_corner{ -0.1f, 1.1f, 0.55f }; + + auto subset = trx.query_aabb(min_corner, max_corner); + EXPECT_EQ(subset->num_streamlines(), static_cast(kInsideCount)); + EXPECT_EQ(subset->num_vertices(), static_cast(kInsideCount * kPointsPerStreamline)); + subset->close(); +} + // resize() with default arguments is a no-op when sizes already match. TEST(TrxFileTpp, ResizeNoChange) { const fs::path data_dir = create_float_trx_dir(); @@ -326,3 +407,107 @@ TEST(TrxFileTpp, ResizeDeleteDpgCloses) { std::error_code ec; fs::remove_all(data_dir, ec); } + +TEST(TrxFileTpp, NormalizeForSaveRejectsNonMonotonicOffsets) { + const fs::path data_dir = create_float_trx_dir(); + auto src_reader = load_trx_dir(data_dir); + auto *src = src_reader.get(); + + ASSERT_GE(src->streamlines->_offsets.size(), 3); + src->streamlines->_offsets(1) = 5; + src->streamlines->_offsets(2) = 4; + + EXPECT_THROW(src->normalize_for_save(), std::runtime_error); + + src->close(); + + std::error_code ec; + fs::remove_all(data_dir, ec); +} + +TEST(TrxFileTpp, NormalizeForSaveRecomputesLengthsAndHeader) { + const fs::path data_dir = create_float_trx_dir(); + auto reader = load_trx_dir(data_dir); + auto *trx = reader.get(); + + ASSERT_EQ(trx->streamlines->_offsets.size(), 3); + trx->streamlines->_lengths(0) = 99; + trx->streamlines->_lengths(1) = 99; + trx->header = _json_set(trx->header, "NB_STREAMLINES", 123); + trx->header = _json_set(trx->header, "NB_VERTICES", 456); + + trx->normalize_for_save(); + + EXPECT_EQ(trx->streamlines->_lengths(0), 2u); + EXPECT_EQ(trx->streamlines->_lengths(1), 2u); + EXPECT_EQ(trx->header["NB_STREAMLINES"].int_value(), 2); + EXPECT_EQ(trx->header["NB_VERTICES"].int_value(), 4); + + trx->close(); + std::error_code ec; + fs::remove_all(data_dir, ec); +} + +TEST(TrxFileTpp, LoadFromDirectoryMissingHeader) { + // Directory exists and has files in it, but no header.json. + // Covers the detailed error-diagnostic branch in load_from_directory (lines 980-1006). 
+TEST(TrxFileTpp, NormalizeForSaveRejectsNonMonotonicOffsets) {
+  const fs::path data_dir = create_float_trx_dir();
+  auto src_reader = load_trx_dir(data_dir);
+  auto *src = src_reader.get();
+
+  ASSERT_GE(src->streamlines->_offsets.size(), 3);
+  src->streamlines->_offsets(1) = 5;
+  src->streamlines->_offsets(2) = 4;
+
+  EXPECT_THROW(src->normalize_for_save(), std::runtime_error);
+
+  src->close();
+
+  std::error_code ec;
+  fs::remove_all(data_dir, ec);
+}
+
+TEST(TrxFileTpp, NormalizeForSaveRecomputesLengthsAndHeader) {
+  const fs::path data_dir = create_float_trx_dir();
+  auto reader = load_trx_dir(data_dir);
+  auto *trx = reader.get();
+
+  ASSERT_EQ(trx->streamlines->_offsets.size(), 3);
+  trx->streamlines->_lengths(0) = 99;
+  trx->streamlines->_lengths(1) = 99;
+  trx->header = _json_set(trx->header, "NB_STREAMLINES", 123);
+  trx->header = _json_set(trx->header, "NB_VERTICES", 456);
+
+  trx->normalize_for_save();
+
+  EXPECT_EQ(trx->streamlines->_lengths(0), 2u);
+  EXPECT_EQ(trx->streamlines->_lengths(1), 2u);
+  EXPECT_EQ(trx->header["NB_STREAMLINES"].int_value(), 2);
+  EXPECT_EQ(trx->header["NB_VERTICES"].int_value(), 4);
+
+  trx->close();
+  std::error_code ec;
+  fs::remove_all(data_dir, ec);
+}
+
+TEST(TrxFileTpp, LoadFromDirectoryMissingHeader) {
+  // Directory exists and has files in it, but no header.json.
+  // Covers the detailed error-diagnostic branch in load_from_directory (lines 980-1006).
+  auto tmp_dir = make_temp_test_dir("trx_no_header");
+  const fs::path dummy = tmp_dir / "positions.3.float32";
+  std::ofstream f(dummy.string(), std::ios::binary);
+  f.close();
+
+  EXPECT_THROW(TrxFile::load_from_directory(tmp_dir.string()), std::runtime_error);
+
+  std::error_code ec;
+  fs::remove_all(tmp_dir, ec);
+}
+
+TEST(TrxFileTpp, TrxStreamFloat16Unbuffered) {
+  // TrxStream("float16") with default unbuffered mode.
+  // Covers the float16 unbuffered write path in push_streamline (lines 1642-1650)
+  // and the float16 read-back loop in finalize (lines 1958-1966).
+  auto tmp_dir = make_temp_test_dir("trx_f16_unbuf");
+  const fs::path out_path = tmp_dir / "f16.trx";
+
+  TrxStream proto("float16");
+  std::vector<float> sl1 = {0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f};
+  std::vector<float> sl2 = {2.0f, 0.0f, 0.0f};
+  proto.push_streamline(sl1);
+  proto.push_streamline(sl2);
+  proto.finalize(out_path.string(), ZIP_CM_STORE);
+
+  auto trx = load_any(out_path.string());
+  EXPECT_EQ(trx.num_streamlines(), 2u);
+  EXPECT_EQ(trx.num_vertices(), 3u);
+  trx.close();
+
+  std::error_code ec;
+  fs::remove_all(tmp_dir, ec);
+}
+
+TEST(TrxFileTpp, TrxStreamFloat16Buffered) {
+  // TrxStream("float16") with a small position buffer.
+  // Pushing two single-point streamlines fills the buffer (6 half-values >= max 6)
+  // and triggers flush_positions_buffer mid-stream (lines 1592-1603, 1660-1673).
+  // finalize then calls flush_positions_buffer again with an empty buffer,
+  // hitting the early-return path (lines 1592-1593).
+  auto tmp_dir = make_temp_test_dir("trx_f16_buf");
+  const fs::path out_path = tmp_dir / "f16_buf.trx";
+
+  TrxStream proto("float16");
+  // 12 bytes / 2 bytes per half = 6 half entries = 2 xyz triplets → flush after 2 points
+  proto.set_positions_buffer_max_bytes(12);
+  std::vector<float> pt = {1.0f, 2.0f, 3.0f};
+  proto.push_streamline(pt);  // buffer: 3 halves
+  proto.push_streamline(pt);  // buffer: 6 halves >= 6 → flush
+  proto.push_streamline(pt);  // buffer: 3 halves (after flush)
+  proto.finalize(out_path.string(), ZIP_CM_STORE);  // flush remainder, then early-return
+
+  auto trx = load_any(out_path.string());
+  EXPECT_EQ(trx.num_streamlines(), 3u);
+  EXPECT_EQ(trx.num_vertices(), 3u);
+  trx.close();
+
+  std::error_code ec;
+  fs::remove_all(tmp_dir, ec);
+}
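+
+// Note on the byte-based flush threshold exercised above: with float32
+// positions the same set_positions_buffer_max_bytes(12) cap would hold only
+// one xyz triplet (3 * 4 bytes), so the position dtype determines how many
+// points accumulate per flush.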