diff --git a/src/PointNeighbors.jl b/src/PointNeighbors.jl index 58b38687..faa8cefc 100644 --- a/src/PointNeighbors.jl +++ b/src/PointNeighbors.jl @@ -21,7 +21,8 @@ include("nhs_precomputed.jl") include("gpu.jl") export foreach_point_neighbor, foreach_point_neighbor_unsafe, - foreach_neighbor, foreach_neighbor_unsafe + foreach_neighbor, foreach_neighbor_unsafe, + mapreduce_neighbor, mapreduce_neighbor_unsafe export TrivialNeighborhoodSearch, GridNeighborhoodSearch, PrecomputedNeighborhoodSearch export DictionaryCellList, FullGridCellList, SpatialHashingCellList export DynamicVectorOfVectors diff --git a/src/neighborhood_search.jl b/src/neighborhood_search.jl index 528c40de..1dc2f05d 100644 --- a/src/neighborhood_search.jl +++ b/src/neighborhood_search.jl @@ -262,13 +262,17 @@ See [`foreach_neighbor_unsafe`](@ref) for a version that skips all bounds checks point, point_coords, search_radius) end +@inline foreach_neighbor_op(::Any, ::Any) = nothing + # This is a function barrier to prevent the `@inbounds` in `foreach_neighbor` # from propagating into the neighbor loop, which is not safe. @inline function foreach_neighbor(f, neighbor_coords, neighborhood_search::AbstractNeighborhoodSearch, point, point_coords, search_radius) - foreach_neighbor_inner(f, neighbor_coords, neighborhood_search, - point, point_coords, search_radius) + mapreduce_neighbor_inner(f, foreach_neighbor_op, + neighbor_coords, neighborhood_search, + point, point_coords, search_radius, nothing) + return nothing end """ @@ -312,8 +316,67 @@ Note that all these bounds checks are safe to skip if point_coords = @inbounds extract_svector(system_coords, Val(ndims(neighborhood_search)), point) - @inbounds foreach_neighbor_inner(f, neighbor_coords, neighborhood_search, - point, point_coords, search_radius) + @inbounds mapreduce_neighbor_inner(f, foreach_neighbor_op, + neighbor_coords, neighborhood_search, + point, point_coords, search_radius, nothing) + return nothing +end + +""" + mapreduce_neighbor(f, op, system_coords, neighbor_coords, + neighborhood_search::AbstractNeighborhoodSearch, point; + init, search_radius = search_radius(neighborhood_search)) + +Apply `f(i, j, pos_diff, d)` to every neighbor of `point` and reduce the results with +the binary operator `op`, analogous to `mapreduce(f, op, collection)`. + +The keyword argument `init` is required and is returned if `point` has no neighbors. +This method performs the same bounds checks as [`foreach_neighbor`](@ref). +See [`mapreduce_neighbor_unsafe`](@ref) for a version that skips all bounds checks. +""" +@propagate_inbounds function mapreduce_neighbor(f, op, system_coords, neighbor_coords, + neighborhood_search::AbstractNeighborhoodSearch, + point; init, + search_radius = search_radius(neighborhood_search)) + point_coords = extract_svector(system_coords, Val(ndims(neighborhood_search)), point) + + mapreduce_neighbor(f, op, neighbor_coords, neighborhood_search, + point, point_coords, search_radius, init) +end + +# This is a function barrier to prevent the `@inbounds` in `mapreduce_neighbor` +# from propagating into the neighbor loop, which is not safe. +@inline function mapreduce_neighbor(f, op, neighbor_coords, + neighborhood_search::AbstractNeighborhoodSearch, + point, point_coords, search_radius, init) + mapreduce_neighbor_inner(f, op, neighbor_coords, neighborhood_search, + point, point_coords, search_radius, init) +end + +""" + mapreduce_neighbor_unsafe(f, op, system_coords, neighbor_coords, + neighborhood_search::AbstractNeighborhoodSearch, point; + init, search_radius = search_radius(neighborhood_search)) + +Like [`mapreduce_neighbor`](@ref), but skips **all** bounds checks. + +See [`foreach_neighbor_unsafe`](@ref) for details on which bounds checks are skipped +and when it is safe to skip them. + +!!! warning + Use this only when `point` is known to be in bounds of `system_coords` + and when `neighborhood_search` is known to be initialized correctly for + `system_coords` and `neighbor_coords`. +""" +@inline function mapreduce_neighbor_unsafe(f, op, system_coords, neighbor_coords, + neighborhood_search::AbstractNeighborhoodSearch, + point; init, + search_radius = search_radius(neighborhood_search)) + point_coords = @inbounds extract_svector(system_coords, Val(ndims(neighborhood_search)), + point) + + @inbounds mapreduce_neighbor_inner(f, op, neighbor_coords, neighborhood_search, + point, point_coords, search_radius, init) end # This is the generic function that is called for `TrivialNeighborhoodSearch`. @@ -321,11 +384,14 @@ end # performance. `PrecomputedNeighborhoodSearch` can skip the distance check altogether. # Note that calling this function with `@inbounds` is not safe. # See the comments in `foreach_neighbor_unsafe`. -@propagate_inbounds function foreach_neighbor_inner(f, neighbor_coords, - neighborhood_search::AbstractNeighborhoodSearch, - point, point_coords, search_radius) +@propagate_inbounds function mapreduce_neighbor_inner(f, op, neighbor_coords, + neighborhood_search::AbstractNeighborhoodSearch, + point, point_coords, + search_radius, init) (; periodic_box) = neighborhood_search + reduced = init + for neighbor in eachneighbor(point_coords, neighborhood_search) neighbor_point_coords = extract_svector(neighbor_coords, Val(ndims(neighborhood_search)), neighbor) @@ -341,11 +407,14 @@ end if distance2 <= search_radius^2 distance = sqrt(distance2) - # Inline to avoid loss of performance - # compared to not using `foreach_point_neighbor`. - @inline f(point, neighbor, pos_diff, distance) + # Inline to avoid loss of performance compared to not using this function + # and unrolling everything. + value = @inline f(point, neighbor, pos_diff, distance) + reduced = @inline op(reduced, value) end end + + return reduced end @inline function compute_periodic_distance(pos_diff, distance2, search_radius, diff --git a/src/nhs_grid.jl b/src/nhs_grid.jl index 8e3c1ff2..6bc5253b 100644 --- a/src/nhs_grid.jl +++ b/src/nhs_grid.jl @@ -510,11 +510,13 @@ end # than looping over `eachneighbor`. # Note that calling this function with `@inbounds` is not safe. # See the comments in `foreach_neighbor_unsafe`. -@propagate_inbounds function foreach_neighbor_inner(f, neighbor_coords, - neighborhood_search::GridNeighborhoodSearch, - point, point_coords, search_radius) +@propagate_inbounds function mapreduce_neighbor_inner(f, op, neighbor_coords, + neighborhood_search::GridNeighborhoodSearch, + point, point_coords, + search_radius, init) (; cell_list, periodic_box) = neighborhood_search cell = cell_coords(point_coords, neighborhood_search) + reduced = init for neighbor_cell_ in neighboring_cells(cell, neighborhood_search) neighbor_cell = Tuple(neighbor_cell_) @@ -545,8 +547,6 @@ end search_radius, periodic_box) if distance2 <= search_radius^2 - distance = sqrt(distance2) - # If this cell has a collision, check if this point belongs to this cell # (only with `SpatialHashingCellList`). if cell_collision && @@ -555,12 +555,17 @@ end continue end - # Inline to avoid loss of performance - # compared to not using `foreach_point_neighbor`. - @inline f(point, neighbor, pos_diff, distance) + distance = sqrt(distance2) + + # Inline to avoid loss of performance compared to not using this function + # and unrolling everything. + value = @inline f(point, neighbor, pos_diff, distance) + reduced = @inline op(reduced, value) end end end + + return reduced end @inline function neighboring_cells(cell, neighborhood_search) diff --git a/src/nhs_precomputed.jl b/src/nhs_precomputed.jl index c7ddceca..94b99838 100644 --- a/src/nhs_precomputed.jl +++ b/src/nhs_precomputed.jl @@ -202,14 +202,17 @@ end # Note that calling this function with `@inbounds` is not safe. # See the comments in `foreach_neighbor_unsafe`. -@propagate_inbounds function foreach_neighbor_inner(f, neighbor_coords, - neighborhood_search::PrecomputedNeighborhoodSearch, - point, point_coords, search_radius) +@propagate_inbounds function mapreduce_neighbor_inner(f, op, neighbor_coords, + neighborhood_search::PrecomputedNeighborhoodSearch, + point, point_coords, + search_radius, init) (; periodic_box, neighbor_lists) = neighborhood_search # Making the following `@inbounds` is not safe because the neighbor list # might not contain `point` if the NHS was not initialized correctly. neighbors = neighbor_lists[point] + reduced = init + for neighbor_ in eachindex(neighbors) neighbor = @inbounds neighbors[neighbor_] @@ -229,10 +232,13 @@ end distance = sqrt(distance2) - # Inline to avoid loss of performance - # compared to not using `foreach_point_neighbor`. - @inline f(point, neighbor, pos_diff, distance) + # Inline to avoid loss of performance compared to not using this function + # and unrolling everything. + value = @inline f(point, neighbor, pos_diff, distance) + reduced = @inline op(reduced, value) end + + return reduced end function copy_neighborhood_search(nhs::PrecomputedNeighborhoodSearch, diff --git a/test/neighborhood_search.jl b/test/neighborhood_search.jl index 3820e11d..1fee3f64 100644 --- a/test/neighborhood_search.jl +++ b/test/neighborhood_search.jl @@ -113,7 +113,7 @@ foreach_point_neighbor(coords, coords, nhs, points = axes(coords, 2)) do point, neighbor, pos_diff, distance - append!(neighbors[point], neighbor) + push!(neighbors[point], neighbor) end # All of these tests are designed to yield the same neighbor lists. @@ -163,7 +163,7 @@ neighbor, pos_diff, distance - append!(neighbors_expected[point], neighbor) + push!(neighbors_expected[point], neighbor) end # Expand the domain by `search_radius`, as we need the neighboring cells of @@ -275,7 +275,7 @@ neighbor, pos_diff, distance - append!(neighbors[point], neighbor) + push!(neighbors[point], neighbor) end @test sort.(neighbors) == neighbors_expected @@ -287,11 +287,20 @@ for point in axes(coords, 2) foreach_neighbor(coords, coords, nhs, point) do point, neighbor, pos_diff, distance - append!(neighbors_manual[point], neighbor) + push!(neighbors_manual[point], neighbor) end end @test sort.(neighbors_manual) == neighbors_expected + + # Test that `foreach_neighbor` does not allocate. + point = first(axes(coords, 2)) + function allocations_empty_foreach_neighbor(coords, nhs, point) + @allocated(foreach_neighbor((point, neighbor, pos_diff, + distance) -> nothing, + coords, coords, nhs, point)) + end + @test allocations_empty_foreach_neighbor(coords, nhs, point) == 0 end # Repeat with foreach_point_neighbor_unsafe @@ -302,7 +311,7 @@ neighbor, pos_diff, distance - append!(neighbors_unsafe[point], neighbor) + push!(neighbors_unsafe[point], neighbor) end @test sort.(neighbors_unsafe) == neighbors_expected @@ -315,12 +324,69 @@ foreach_neighbor_unsafe(coords, coords, nhs, point) do point, neighbor, pos_diff, distance - append!(neighbors_manual_unsafe[point], neighbor) + push!(neighbors_manual_unsafe[point], neighbor) end end @test sort.(neighbors_manual_unsafe) == neighbors_expected end + + @testset "`mapreduce_neighbor`" begin + neighbor_sums = map(axes(coords, 2)) do point + mapreduce_neighbor(+, coords, coords, nhs, point; + init = 0) do point_, neighbor, + pos_diff, distance + point_ == point || error("incorrect point index") + neighbor + end + end + + @test neighbor_sums == sum.(neighbors_expected) + + # Test that `mapreduce_neighbor` does not allocate. + point = first(axes(coords, 2)) + function allocations_count_neighbors(coords, nhs, point) + @allocated(mapreduce_neighbor((point, neighbor, pos_diff, + distance) -> neighbor, + +, coords, coords, nhs, point; + init = 0)) + end + @test allocations_count_neighbors(coords, nhs, point) == 0 + + @test_throws UndefKeywordError mapreduce_neighbor(+, coords, coords, + nhs, + first(axes(coords, + 2))) do point_, + neighbor, + pos_diff, + distance + neighbor + end + end + + @testset "`mapreduce_neighbor_unsafe`" begin + neighbor_sums = map(axes(coords, 2)) do point + mapreduce_neighbor_unsafe(+, coords, coords, nhs, point; + init = 0) do point_, neighbor, + pos_diff, distance + point_ == point || error("incorrect point index") + neighbor + end + end + + @test neighbor_sums == sum.(neighbors_expected) + + @test_throws UndefKeywordError mapreduce_neighbor_unsafe(+, + coords, coords, + nhs, + first(axes(coords, + 2))) do point_, + neighbor, + pos_diff, + distance + neighbor + end + end end end end