From 7944c2dbbd622a79aca4dc83172ca51bbaa426e1 Mon Sep 17 00:00:00 2001 From: Erik Faulhaber <44124897+efaulhaber@users.noreply.github.com> Date: Mon, 25 May 2026 17:11:35 +0200 Subject: [PATCH 1/5] Add functions `mapreduce_neighbor` and `mapreduce_neighbor_unsafe` --- src/PointNeighbors.jl | 3 +- src/neighborhood_search.jl | 89 ++++++++++++++++++++++++++++++++----- src/nhs_grid.jl | 21 +++++---- src/nhs_precomputed.jl | 18 +++++--- test/neighborhood_search.jl | 47 ++++++++++++++++++++ 5 files changed, 153 insertions(+), 25 deletions(-) diff --git a/src/PointNeighbors.jl b/src/PointNeighbors.jl index 58b38687..faa8cefc 100644 --- a/src/PointNeighbors.jl +++ b/src/PointNeighbors.jl @@ -21,7 +21,8 @@ include("nhs_precomputed.jl") include("gpu.jl") export foreach_point_neighbor, foreach_point_neighbor_unsafe, - foreach_neighbor, foreach_neighbor_unsafe + foreach_neighbor, foreach_neighbor_unsafe, + mapreduce_neighbor, mapreduce_neighbor_unsafe export TrivialNeighborhoodSearch, GridNeighborhoodSearch, PrecomputedNeighborhoodSearch export DictionaryCellList, FullGridCellList, SpatialHashingCellList export DynamicVectorOfVectors diff --git a/src/neighborhood_search.jl b/src/neighborhood_search.jl index 528c40de..1dc2f05d 100644 --- a/src/neighborhood_search.jl +++ b/src/neighborhood_search.jl @@ -262,13 +262,17 @@ See [`foreach_neighbor_unsafe`](@ref) for a version that skips all bounds checks point, point_coords, search_radius) end +@inline foreach_neighbor_op(::Any, ::Any) = nothing + # This is a function barrier to prevent the `@inbounds` in `foreach_neighbor` # from propagating into the neighbor loop, which is not safe. @inline function foreach_neighbor(f, neighbor_coords, neighborhood_search::AbstractNeighborhoodSearch, point, point_coords, search_radius) - foreach_neighbor_inner(f, neighbor_coords, neighborhood_search, - point, point_coords, search_radius) + mapreduce_neighbor_inner(f, foreach_neighbor_op, + neighbor_coords, neighborhood_search, + point, point_coords, search_radius, nothing) + return nothing end """ @@ -312,8 +316,67 @@ Note that all these bounds checks are safe to skip if point_coords = @inbounds extract_svector(system_coords, Val(ndims(neighborhood_search)), point) - @inbounds foreach_neighbor_inner(f, neighbor_coords, neighborhood_search, - point, point_coords, search_radius) + @inbounds mapreduce_neighbor_inner(f, foreach_neighbor_op, + neighbor_coords, neighborhood_search, + point, point_coords, search_radius, nothing) + return nothing +end + +""" + mapreduce_neighbor(f, op, system_coords, neighbor_coords, + neighborhood_search::AbstractNeighborhoodSearch, point; + init, search_radius = search_radius(neighborhood_search)) + +Apply `f(i, j, pos_diff, d)` to every neighbor of `point` and reduce the results with +the binary operator `op`, analogous to `mapreduce(f, op, collection)`. + +The keyword argument `init` is required and is returned if `point` has no neighbors. +This method performs the same bounds checks as [`foreach_neighbor`](@ref). +See [`mapreduce_neighbor_unsafe`](@ref) for a version that skips all bounds checks. +""" +@propagate_inbounds function mapreduce_neighbor(f, op, system_coords, neighbor_coords, + neighborhood_search::AbstractNeighborhoodSearch, + point; init, + search_radius = search_radius(neighborhood_search)) + point_coords = extract_svector(system_coords, Val(ndims(neighborhood_search)), point) + + mapreduce_neighbor(f, op, neighbor_coords, neighborhood_search, + point, point_coords, search_radius, init) +end + +# This is a function barrier to prevent the `@inbounds` in `mapreduce_neighbor` +# from propagating into the neighbor loop, which is not safe. +@inline function mapreduce_neighbor(f, op, neighbor_coords, + neighborhood_search::AbstractNeighborhoodSearch, + point, point_coords, search_radius, init) + mapreduce_neighbor_inner(f, op, neighbor_coords, neighborhood_search, + point, point_coords, search_radius, init) +end + +""" + mapreduce_neighbor_unsafe(f, op, system_coords, neighbor_coords, + neighborhood_search::AbstractNeighborhoodSearch, point; + init, search_radius = search_radius(neighborhood_search)) + +Like [`mapreduce_neighbor`](@ref), but skips **all** bounds checks. + +See [`foreach_neighbor_unsafe`](@ref) for details on which bounds checks are skipped +and when it is safe to skip them. + +!!! warning + Use this only when `point` is known to be in bounds of `system_coords` + and when `neighborhood_search` is known to be initialized correctly for + `system_coords` and `neighbor_coords`. +""" +@inline function mapreduce_neighbor_unsafe(f, op, system_coords, neighbor_coords, + neighborhood_search::AbstractNeighborhoodSearch, + point; init, + search_radius = search_radius(neighborhood_search)) + point_coords = @inbounds extract_svector(system_coords, Val(ndims(neighborhood_search)), + point) + + @inbounds mapreduce_neighbor_inner(f, op, neighbor_coords, neighborhood_search, + point, point_coords, search_radius, init) end # This is the generic function that is called for `TrivialNeighborhoodSearch`. @@ -321,11 +384,14 @@ end # performance. `PrecomputedNeighborhoodSearch` can skip the distance check altogether. # Note that calling this function with `@inbounds` is not safe. # See the comments in `foreach_neighbor_unsafe`. -@propagate_inbounds function foreach_neighbor_inner(f, neighbor_coords, - neighborhood_search::AbstractNeighborhoodSearch, - point, point_coords, search_radius) +@propagate_inbounds function mapreduce_neighbor_inner(f, op, neighbor_coords, + neighborhood_search::AbstractNeighborhoodSearch, + point, point_coords, + search_radius, init) (; periodic_box) = neighborhood_search + reduced = init + for neighbor in eachneighbor(point_coords, neighborhood_search) neighbor_point_coords = extract_svector(neighbor_coords, Val(ndims(neighborhood_search)), neighbor) @@ -341,11 +407,14 @@ end if distance2 <= search_radius^2 distance = sqrt(distance2) - # Inline to avoid loss of performance - # compared to not using `foreach_point_neighbor`. - @inline f(point, neighbor, pos_diff, distance) + # Inline to avoid loss of performance compared to not using this function + # and unrolling everything. + value = @inline f(point, neighbor, pos_diff, distance) + reduced = @inline op(reduced, value) end end + + return reduced end @inline function compute_periodic_distance(pos_diff, distance2, search_radius, diff --git a/src/nhs_grid.jl b/src/nhs_grid.jl index 8e3c1ff2..6bc5253b 100644 --- a/src/nhs_grid.jl +++ b/src/nhs_grid.jl @@ -510,11 +510,13 @@ end # than looping over `eachneighbor`. # Note that calling this function with `@inbounds` is not safe. # See the comments in `foreach_neighbor_unsafe`. -@propagate_inbounds function foreach_neighbor_inner(f, neighbor_coords, - neighborhood_search::GridNeighborhoodSearch, - point, point_coords, search_radius) +@propagate_inbounds function mapreduce_neighbor_inner(f, op, neighbor_coords, + neighborhood_search::GridNeighborhoodSearch, + point, point_coords, + search_radius, init) (; cell_list, periodic_box) = neighborhood_search cell = cell_coords(point_coords, neighborhood_search) + reduced = init for neighbor_cell_ in neighboring_cells(cell, neighborhood_search) neighbor_cell = Tuple(neighbor_cell_) @@ -545,8 +547,6 @@ end search_radius, periodic_box) if distance2 <= search_radius^2 - distance = sqrt(distance2) - # If this cell has a collision, check if this point belongs to this cell # (only with `SpatialHashingCellList`). if cell_collision && @@ -555,12 +555,17 @@ end continue end - # Inline to avoid loss of performance - # compared to not using `foreach_point_neighbor`. - @inline f(point, neighbor, pos_diff, distance) + distance = sqrt(distance2) + + # Inline to avoid loss of performance compared to not using this function + # and unrolling everything. + value = @inline f(point, neighbor, pos_diff, distance) + reduced = @inline op(reduced, value) end end end + + return reduced end @inline function neighboring_cells(cell, neighborhood_search) diff --git a/src/nhs_precomputed.jl b/src/nhs_precomputed.jl index c7ddceca..94b99838 100644 --- a/src/nhs_precomputed.jl +++ b/src/nhs_precomputed.jl @@ -202,14 +202,17 @@ end # Note that calling this function with `@inbounds` is not safe. # See the comments in `foreach_neighbor_unsafe`. -@propagate_inbounds function foreach_neighbor_inner(f, neighbor_coords, - neighborhood_search::PrecomputedNeighborhoodSearch, - point, point_coords, search_radius) +@propagate_inbounds function mapreduce_neighbor_inner(f, op, neighbor_coords, + neighborhood_search::PrecomputedNeighborhoodSearch, + point, point_coords, + search_radius, init) (; periodic_box, neighbor_lists) = neighborhood_search # Making the following `@inbounds` is not safe because the neighbor list # might not contain `point` if the NHS was not initialized correctly. neighbors = neighbor_lists[point] + reduced = init + for neighbor_ in eachindex(neighbors) neighbor = @inbounds neighbors[neighbor_] @@ -229,10 +232,13 @@ end distance = sqrt(distance2) - # Inline to avoid loss of performance - # compared to not using `foreach_point_neighbor`. - @inline f(point, neighbor, pos_diff, distance) + # Inline to avoid loss of performance compared to not using this function + # and unrolling everything. + value = @inline f(point, neighbor, pos_diff, distance) + reduced = @inline op(reduced, value) end + + return reduced end function copy_neighborhood_search(nhs::PrecomputedNeighborhoodSearch, diff --git a/test/neighborhood_search.jl b/test/neighborhood_search.jl index 3820e11d..104fdd29 100644 --- a/test/neighborhood_search.jl +++ b/test/neighborhood_search.jl @@ -321,6 +321,53 @@ @test sort.(neighbors_manual_unsafe) == neighbors_expected end + + @testset "`mapreduce_neighbor`" begin + neighbor_sums = map(axes(coords, 2)) do point + mapreduce_neighbor(+, coords, coords, nhs, point; + init = 0) do point_, neighbor, + pos_diff, distance + point_ == point || error("incorrect point index") + neighbor + end + end + + @test neighbor_sums == sum.(neighbors_expected) + + @test_throws UndefKeywordError mapreduce_neighbor(+, coords, coords, + nhs, + first(axes(coords, + 2))) do point_, + neighbor, + pos_diff, + distance + neighbor + end + end + + @testset "`mapreduce_neighbor_unsafe`" begin + neighbor_sums = map(axes(coords, 2)) do point + mapreduce_neighbor_unsafe(+, coords, coords, nhs, point; + init = 0) do point_, neighbor, + pos_diff, distance + point_ == point || error("incorrect point index") + neighbor + end + end + + @test neighbor_sums == sum.(neighbors_expected) + + @test_throws UndefKeywordError mapreduce_neighbor_unsafe(+, + coords, coords, + nhs, + first(axes(coords, + 2))) do point_, + neighbor, + pos_diff, + distance + neighbor + end + end end end end From 48bd9b9f8d1d7ea06fe16a7293b56f56d4bf1c47 Mon Sep 17 00:00:00 2001 From: Erik Faulhaber <44124897+efaulhaber@users.noreply.github.com> Date: Mon, 25 May 2026 17:55:11 +0200 Subject: [PATCH 2/5] Add allocation check --- test/neighborhood_search.jl | 29 +++++++++++++++++++++++------ 1 file changed, 23 insertions(+), 6 deletions(-) diff --git a/test/neighborhood_search.jl b/test/neighborhood_search.jl index 104fdd29..0f8e9022 100644 --- a/test/neighborhood_search.jl +++ b/test/neighborhood_search.jl @@ -113,7 +113,7 @@ foreach_point_neighbor(coords, coords, nhs, points = axes(coords, 2)) do point, neighbor, pos_diff, distance - append!(neighbors[point], neighbor) + push!(neighbors[point], neighbor) end # All of these tests are designed to yield the same neighbor lists. @@ -163,7 +163,7 @@ neighbor, pos_diff, distance - append!(neighbors_expected[point], neighbor) + push!(neighbors_expected[point], neighbor) end # Expand the domain by `search_radius`, as we need the neighboring cells of @@ -275,7 +275,7 @@ neighbor, pos_diff, distance - append!(neighbors[point], neighbor) + push!(neighbors[point], neighbor) end @test sort.(neighbors) == neighbors_expected @@ -287,11 +287,19 @@ for point in axes(coords, 2) foreach_neighbor(coords, coords, nhs, point) do point, neighbor, pos_diff, distance - append!(neighbors_manual[point], neighbor) + push!(neighbors_manual[point], neighbor) end end @test sort.(neighbors_manual) == neighbors_expected + + # Test that `foreach_neighbor` does not allocate. + point = first(axes(coords, 2)) + foreach_neighbor((point, neighbor, pos_diff, distance) -> nothing, + coords, coords, nhs, point) + @test @allocated(foreach_neighbor((point, neighbor, pos_diff, + distance) -> nothing, + coords, coords, nhs, point)) == 0 end # Repeat with foreach_point_neighbor_unsafe @@ -302,7 +310,7 @@ neighbor, pos_diff, distance - append!(neighbors_unsafe[point], neighbor) + push!(neighbors_unsafe[point], neighbor) end @test sort.(neighbors_unsafe) == neighbors_expected @@ -315,7 +323,7 @@ foreach_neighbor_unsafe(coords, coords, nhs, point) do point, neighbor, pos_diff, distance - append!(neighbors_manual_unsafe[point], neighbor) + push!(neighbors_manual_unsafe[point], neighbor) end end @@ -334,6 +342,15 @@ @test neighbor_sums == sum.(neighbors_expected) + # Test that `mapreduce_neighbor` does not allocate. + point = first(axes(coords, 2)) + mapreduce_neighbor((point, neighbor, pos_diff, distance) -> neighbor, + +, coords, coords, nhs, point; init = 0) + @test @allocated(mapreduce_neighbor((point, neighbor, pos_diff, + distance) -> neighbor, + +, coords, coords, nhs, point; + init = 0)) == 0 + @test_throws UndefKeywordError mapreduce_neighbor(+, coords, coords, nhs, first(axes(coords, From 8e331bc64d3617e196dd227f5eb0511a73b258ef Mon Sep 17 00:00:00 2001 From: Erik Faulhaber <44124897+efaulhaber@users.noreply.github.com> Date: Mon, 25 May 2026 18:47:44 +0200 Subject: [PATCH 3/5] Remove allocations --- test/neighborhood_search.jl | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/test/neighborhood_search.jl b/test/neighborhood_search.jl index 0f8e9022..75ca4512 100644 --- a/test/neighborhood_search.jl +++ b/test/neighborhood_search.jl @@ -295,11 +295,12 @@ # Test that `foreach_neighbor` does not allocate. point = first(axes(coords, 2)) - foreach_neighbor((point, neighbor, pos_diff, distance) -> nothing, - coords, coords, nhs, point) - @test @allocated(foreach_neighbor((point, neighbor, pos_diff, - distance) -> nothing, - coords, coords, nhs, point)) == 0 + function empty_foreach_neighbor(coords, nhs, point) + foreach_neighbor((point, neighbor, pos_diff, distance) -> nothing, + coords, coords, nhs, point) + end + empty_foreach_neighbor(coords, nhs, point) + @test @allocated(empty_foreach_neighbor(coords, nhs, point)) == 0 end # Repeat with foreach_point_neighbor_unsafe @@ -344,12 +345,12 @@ # Test that `mapreduce_neighbor` does not allocate. point = first(axes(coords, 2)) - mapreduce_neighbor((point, neighbor, pos_diff, distance) -> neighbor, - +, coords, coords, nhs, point; init = 0) - @test @allocated(mapreduce_neighbor((point, neighbor, pos_diff, - distance) -> neighbor, - +, coords, coords, nhs, point; - init = 0)) == 0 + function count_neighbors(coords, nhs, point) + mapreduce_neighbor((point, neighbor, pos_diff, distance) -> neighbor, + +, coords, coords, nhs, point; init = 0) + end + count_neighbors(coords, nhs, point) + @test @allocated(count_neighbors(coords, nhs, point)) == 0 @test_throws UndefKeywordError mapreduce_neighbor(+, coords, coords, nhs, From 9fbfee1784bd40d2c35d3a41ec6b3983e91c5541 Mon Sep 17 00:00:00 2001 From: Erik Faulhaber <44124897+efaulhaber@users.noreply.github.com> Date: Mon, 25 May 2026 18:50:27 +0200 Subject: [PATCH 4/5] Reformat --- test/neighborhood_search.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/neighborhood_search.jl b/test/neighborhood_search.jl index 75ca4512..a63c242f 100644 --- a/test/neighborhood_search.jl +++ b/test/neighborhood_search.jl @@ -346,7 +346,8 @@ # Test that `mapreduce_neighbor` does not allocate. point = first(axes(coords, 2)) function count_neighbors(coords, nhs, point) - mapreduce_neighbor((point, neighbor, pos_diff, distance) -> neighbor, + mapreduce_neighbor((point, neighbor, pos_diff, + distance) -> neighbor, +, coords, coords, nhs, point; init = 0) end count_neighbors(coords, nhs, point) From 110aef6a5629a8e009b99268c36ce548d14e7f55 Mon Sep 17 00:00:00 2001 From: Erik Faulhaber <44124897+efaulhaber@users.noreply.github.com> Date: Mon, 25 May 2026 19:29:25 +0200 Subject: [PATCH 5/5] Fix tests --- test/neighborhood_search.jl | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/test/neighborhood_search.jl b/test/neighborhood_search.jl index a63c242f..1fee3f64 100644 --- a/test/neighborhood_search.jl +++ b/test/neighborhood_search.jl @@ -295,12 +295,12 @@ # Test that `foreach_neighbor` does not allocate. point = first(axes(coords, 2)) - function empty_foreach_neighbor(coords, nhs, point) - foreach_neighbor((point, neighbor, pos_diff, distance) -> nothing, - coords, coords, nhs, point) + function allocations_empty_foreach_neighbor(coords, nhs, point) + @allocated(foreach_neighbor((point, neighbor, pos_diff, + distance) -> nothing, + coords, coords, nhs, point)) end - empty_foreach_neighbor(coords, nhs, point) - @test @allocated(empty_foreach_neighbor(coords, nhs, point)) == 0 + @test allocations_empty_foreach_neighbor(coords, nhs, point) == 0 end # Repeat with foreach_point_neighbor_unsafe @@ -345,13 +345,13 @@ # Test that `mapreduce_neighbor` does not allocate. point = first(axes(coords, 2)) - function count_neighbors(coords, nhs, point) - mapreduce_neighbor((point, neighbor, pos_diff, - distance) -> neighbor, - +, coords, coords, nhs, point; init = 0) + function allocations_count_neighbors(coords, nhs, point) + @allocated(mapreduce_neighbor((point, neighbor, pos_diff, + distance) -> neighbor, + +, coords, coords, nhs, point; + init = 0)) end - count_neighbors(coords, nhs, point) - @test @allocated(count_neighbors(coords, nhs, point)) == 0 + @test allocations_count_neighbors(coords, nhs, point) == 0 @test_throws UndefKeywordError mapreduce_neighbor(+, coords, coords, nhs,