Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/PointNeighbors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ include("nhs_precomputed.jl")
include("gpu.jl")

export foreach_point_neighbor, foreach_point_neighbor_unsafe,
foreach_neighbor, foreach_neighbor_unsafe
foreach_neighbor, foreach_neighbor_unsafe,
mapreduce_neighbor, mapreduce_neighbor_unsafe
export TrivialNeighborhoodSearch, GridNeighborhoodSearch, PrecomputedNeighborhoodSearch
export DictionaryCellList, FullGridCellList, SpatialHashingCellList
export DynamicVectorOfVectors
Expand Down
89 changes: 79 additions & 10 deletions src/neighborhood_search.jl
Original file line number Diff line number Diff line change
Expand Up @@ -262,13 +262,17 @@ See [`foreach_neighbor_unsafe`](@ref) for a version that skips all bounds checks
point, point_coords, search_radius)
end

@inline foreach_neighbor_op(::Any, ::Any) = nothing

# This is a function barrier to prevent the `@inbounds` in `foreach_neighbor`
# from propagating into the neighbor loop, which is not safe.
@inline function foreach_neighbor(f, neighbor_coords,
neighborhood_search::AbstractNeighborhoodSearch,
point, point_coords, search_radius)
foreach_neighbor_inner(f, neighbor_coords, neighborhood_search,
point, point_coords, search_radius)
mapreduce_neighbor_inner(f, foreach_neighbor_op,
neighbor_coords, neighborhood_search,
point, point_coords, search_radius, nothing)
return nothing
Comment thread
efaulhaber marked this conversation as resolved.
end

"""
Expand Down Expand Up @@ -312,20 +316,82 @@ Note that all these bounds checks are safe to skip if
point_coords = @inbounds extract_svector(system_coords, Val(ndims(neighborhood_search)),
point)

@inbounds foreach_neighbor_inner(f, neighbor_coords, neighborhood_search,
point, point_coords, search_radius)
@inbounds mapreduce_neighbor_inner(f, foreach_neighbor_op,
neighbor_coords, neighborhood_search,
point, point_coords, search_radius, nothing)
return nothing
end

"""
mapreduce_neighbor(f, op, system_coords, neighbor_coords,
neighborhood_search::AbstractNeighborhoodSearch, point;
init, search_radius = search_radius(neighborhood_search))

Apply `f(i, j, pos_diff, d)` to every neighbor of `point` and reduce the results with
the binary operator `op`, analogous to `mapreduce(f, op, collection)`.

The keyword argument `init` is required and is returned if `point` has no neighbors.
This method performs the same bounds checks as [`foreach_neighbor`](@ref).
See [`mapreduce_neighbor_unsafe`](@ref) for a version that skips all bounds checks.
"""
@propagate_inbounds function mapreduce_neighbor(f, op, system_coords, neighbor_coords,
neighborhood_search::AbstractNeighborhoodSearch,
point; init,
search_radius = search_radius(neighborhood_search))
point_coords = extract_svector(system_coords, Val(ndims(neighborhood_search)), point)

mapreduce_neighbor(f, op, neighbor_coords, neighborhood_search,
point, point_coords, search_radius, init)
end

# This is a function barrier to prevent the `@inbounds` in `mapreduce_neighbor`
# from propagating into the neighbor loop, which is not safe.
@inline function mapreduce_neighbor(f, op, neighbor_coords,
neighborhood_search::AbstractNeighborhoodSearch,
point, point_coords, search_radius, init)
Comment thread
efaulhaber marked this conversation as resolved.
mapreduce_neighbor_inner(f, op, neighbor_coords, neighborhood_search,
point, point_coords, search_radius, init)
end

"""
mapreduce_neighbor_unsafe(f, op, system_coords, neighbor_coords,
neighborhood_search::AbstractNeighborhoodSearch, point;
init, search_radius = search_radius(neighborhood_search))

Like [`mapreduce_neighbor`](@ref), but skips **all** bounds checks.

See [`foreach_neighbor_unsafe`](@ref) for details on which bounds checks are skipped
and when it is safe to skip them.

!!! warning
Use this only when `point` is known to be in bounds of `system_coords`
and when `neighborhood_search` is known to be initialized correctly for
`system_coords` and `neighbor_coords`.
"""
@inline function mapreduce_neighbor_unsafe(f, op, system_coords, neighbor_coords,
neighborhood_search::AbstractNeighborhoodSearch,
point; init,
search_radius = search_radius(neighborhood_search))
point_coords = @inbounds extract_svector(system_coords, Val(ndims(neighborhood_search)),
point)

@inbounds mapreduce_neighbor_inner(f, op, neighbor_coords, neighborhood_search,
point, point_coords, search_radius, init)
end

# This is the generic function that is called for `TrivialNeighborhoodSearch`.
# For `GridNeighborhoodSearch`, a specialized function is used for slightly better
# performance. `PrecomputedNeighborhoodSearch` can skip the distance check altogether.
# Note that calling this function with `@inbounds` is not safe.
# See the comments in `foreach_neighbor_unsafe`.
@propagate_inbounds function foreach_neighbor_inner(f, neighbor_coords,
neighborhood_search::AbstractNeighborhoodSearch,
point, point_coords, search_radius)
@propagate_inbounds function mapreduce_neighbor_inner(f, op, neighbor_coords,
neighborhood_search::AbstractNeighborhoodSearch,
point, point_coords,
search_radius, init)
(; periodic_box) = neighborhood_search

reduced = init

for neighbor in eachneighbor(point_coords, neighborhood_search)
neighbor_point_coords = extract_svector(neighbor_coords,
Val(ndims(neighborhood_search)), neighbor)
Expand All @@ -341,11 +407,14 @@ end
if distance2 <= search_radius^2
distance = sqrt(distance2)

# Inline to avoid loss of performance
# compared to not using `foreach_point_neighbor`.
@inline f(point, neighbor, pos_diff, distance)
# Inline to avoid loss of performance compared to not using this function
# and unrolling everything.
value = @inline f(point, neighbor, pos_diff, distance)
reduced = @inline op(reduced, value)
end
end

return reduced
end

@inline function compute_periodic_distance(pos_diff, distance2, search_radius,
Expand Down
21 changes: 13 additions & 8 deletions src/nhs_grid.jl
Original file line number Diff line number Diff line change
Expand Up @@ -510,11 +510,13 @@ end
# than looping over `eachneighbor`.
# Note that calling this function with `@inbounds` is not safe.
# See the comments in `foreach_neighbor_unsafe`.
@propagate_inbounds function foreach_neighbor_inner(f, neighbor_coords,
neighborhood_search::GridNeighborhoodSearch,
point, point_coords, search_radius)
@propagate_inbounds function mapreduce_neighbor_inner(f, op, neighbor_coords,
neighborhood_search::GridNeighborhoodSearch,
point, point_coords,
search_radius, init)
(; cell_list, periodic_box) = neighborhood_search
cell = cell_coords(point_coords, neighborhood_search)
reduced = init

for neighbor_cell_ in neighboring_cells(cell, neighborhood_search)
neighbor_cell = Tuple(neighbor_cell_)
Expand Down Expand Up @@ -545,8 +547,6 @@ end
search_radius, periodic_box)

if distance2 <= search_radius^2
distance = sqrt(distance2)

# If this cell has a collision, check if this point belongs to this cell
# (only with `SpatialHashingCellList`).
if cell_collision &&
Expand All @@ -555,12 +555,17 @@ end
continue
end

# Inline to avoid loss of performance
# compared to not using `foreach_point_neighbor`.
@inline f(point, neighbor, pos_diff, distance)
distance = sqrt(distance2)

# Inline to avoid loss of performance compared to not using this function
# and unrolling everything.
value = @inline f(point, neighbor, pos_diff, distance)
reduced = @inline op(reduced, value)
end
end
end

return reduced
end

@inline function neighboring_cells(cell, neighborhood_search)
Expand Down
18 changes: 12 additions & 6 deletions src/nhs_precomputed.jl
Original file line number Diff line number Diff line change
Expand Up @@ -202,14 +202,17 @@ end

# Note that calling this function with `@inbounds` is not safe.
# See the comments in `foreach_neighbor_unsafe`.
@propagate_inbounds function foreach_neighbor_inner(f, neighbor_coords,
neighborhood_search::PrecomputedNeighborhoodSearch,
point, point_coords, search_radius)
@propagate_inbounds function mapreduce_neighbor_inner(f, op, neighbor_coords,
neighborhood_search::PrecomputedNeighborhoodSearch,
point, point_coords,
search_radius, init)
(; periodic_box, neighbor_lists) = neighborhood_search

# Making the following `@inbounds` is not safe because the neighbor list
# might not contain `point` if the NHS was not initialized correctly.
neighbors = neighbor_lists[point]
reduced = init

for neighbor_ in eachindex(neighbors)
neighbor = @inbounds neighbors[neighbor_]

Expand All @@ -229,10 +232,13 @@ end

distance = sqrt(distance2)

# Inline to avoid loss of performance
# compared to not using `foreach_point_neighbor`.
@inline f(point, neighbor, pos_diff, distance)
# Inline to avoid loss of performance compared to not using this function
# and unrolling everything.
value = @inline f(point, neighbor, pos_diff, distance)
reduced = @inline op(reduced, value)
end

return reduced
end

function copy_neighborhood_search(nhs::PrecomputedNeighborhoodSearch,
Expand Down
78 changes: 72 additions & 6 deletions test/neighborhood_search.jl
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@
foreach_point_neighbor(coords, coords, nhs,
points = axes(coords, 2)) do point, neighbor,
pos_diff, distance
append!(neighbors[point], neighbor)
push!(neighbors[point], neighbor)
end

# All of these tests are designed to yield the same neighbor lists.
Expand Down Expand Up @@ -163,7 +163,7 @@
neighbor,
pos_diff,
distance
append!(neighbors_expected[point], neighbor)
push!(neighbors_expected[point], neighbor)
end

# Expand the domain by `search_radius`, as we need the neighboring cells of
Expand Down Expand Up @@ -275,7 +275,7 @@
neighbor,
pos_diff,
distance
append!(neighbors[point], neighbor)
push!(neighbors[point], neighbor)
end

@test sort.(neighbors) == neighbors_expected
Expand All @@ -287,11 +287,20 @@
for point in axes(coords, 2)
foreach_neighbor(coords, coords, nhs,
point) do point, neighbor, pos_diff, distance
append!(neighbors_manual[point], neighbor)
push!(neighbors_manual[point], neighbor)
end
end

@test sort.(neighbors_manual) == neighbors_expected

# Test that `foreach_neighbor` does not allocate.
point = first(axes(coords, 2))
function allocations_empty_foreach_neighbor(coords, nhs, point)
@allocated(foreach_neighbor((point, neighbor, pos_diff,
distance) -> nothing,
coords, coords, nhs, point))
end
@test allocations_empty_foreach_neighbor(coords, nhs, point) == 0
end

# Repeat with foreach_point_neighbor_unsafe
Expand All @@ -302,7 +311,7 @@
neighbor,
pos_diff,
distance
append!(neighbors_unsafe[point], neighbor)
push!(neighbors_unsafe[point], neighbor)
end

@test sort.(neighbors_unsafe) == neighbors_expected
Expand All @@ -315,12 +324,69 @@
foreach_neighbor_unsafe(coords, coords, nhs,
point) do point, neighbor,
pos_diff, distance
append!(neighbors_manual_unsafe[point], neighbor)
push!(neighbors_manual_unsafe[point], neighbor)
end
end

@test sort.(neighbors_manual_unsafe) == neighbors_expected
end

@testset "`mapreduce_neighbor`" begin
neighbor_sums = map(axes(coords, 2)) do point
mapreduce_neighbor(+, coords, coords, nhs, point;
init = 0) do point_, neighbor,
pos_diff, distance
point_ == point || error("incorrect point index")
neighbor
end
end

@test neighbor_sums == sum.(neighbors_expected)

# Test that `mapreduce_neighbor` does not allocate.
point = first(axes(coords, 2))
function allocations_count_neighbors(coords, nhs, point)
@allocated(mapreduce_neighbor((point, neighbor, pos_diff,
distance) -> neighbor,
+, coords, coords, nhs, point;
init = 0))
end
@test allocations_count_neighbors(coords, nhs, point) == 0

@test_throws UndefKeywordError mapreduce_neighbor(+, coords, coords,
nhs,
first(axes(coords,
2))) do point_,
neighbor,
pos_diff,
distance
neighbor
end
end

@testset "`mapreduce_neighbor_unsafe`" begin
neighbor_sums = map(axes(coords, 2)) do point
mapreduce_neighbor_unsafe(+, coords, coords, nhs, point;
init = 0) do point_, neighbor,
pos_diff, distance
point_ == point || error("incorrect point index")
neighbor
end
end

@test neighbor_sums == sum.(neighbors_expected)

@test_throws UndefKeywordError mapreduce_neighbor_unsafe(+,
coords, coords,
nhs,
first(axes(coords,
2))) do point_,
neighbor,
pos_diff,
distance
neighbor
end
end
end
end
end
Expand Down
Loading