Skip to content

Conversation

@candytaco
Copy link

This pull request adds the a pairwise distance matrix-based neighbor-finding algorithm for EDM, and the option to switch between KDTree and pairwise distance matrices. The pairwise distance matrices allow for reuse and significantly improves runtimes during repeated queries, e.g. CCM.

API additions for neighbor search:

  • Added a neighbor_algorithm parameter to the main API functions allowing users to select between different neighbor-finding algorithms (currently 'kdtree' and 'pdist').
  • The neighbor_algorithm parameter is passed to the EDM classes, and the EDM base class uses it to switch between algorithms.
  • The default value for neighbor_algorithm is kdtree to keep internal calculations consistent with previous versions, except in CCM, which is pdist for performance speedups.

Performance improvements:

  • Refactored the CCM class to use different logic for the two neighbor-finding algorithms
    • for kdtree the logic is kept the same
    • for pdist, on each sample of each libSize, the distance matrix is copied, and only the sampled library points are kept. This is done by setting the pairwise distances for all other points to np.inf. This modified distance matrix is then re-queried for nearest neighbors.
  • The pdist implementation offers a
    • ~25x speedup over the KDTree implementation on 2x Xeon 5318Y (48 cores/92 threads),
image
  • and a ~3x speedup over KDTree on a Xeon E3-1245v2 (4 cores/8 threads)
image
  • For Simplex, pdist at least matches KDtree performance (Xeon E3-1245v2) and can provide a ~6x speedup (2x Xeon 5318Y)

Codebase changes:

  • Added neighbor finder classes in NeighborFinder.py to be used by EDM classes. They provide a consistent API for both KDTree and pairwise distance-based neighbor finding. The NeighborFinderBase class specifies an interface for adding additional future algorithms. The KDTreeNeighborFinder and PairwiseDistanceNeighborFinder classes inherit from the base class.
  • A neighbor_finder field for the EDM base class that is an object inheriting from NeighborFinderBase that implement the actual neighbor finding algorithm.
  • The FindNeighbors method is moved into the EDM class file, and some of its internal logic is factored out into separate methods and properties:
    • EDM.check_lib_valid() for checking that lib indices are valid
    • EDM.excludionRadius_knn To check if exclusion radius adjustment is needed
    • EDM.knn_ for getting actual number of nearest neighbors to query
    • EDM.map_knn_indices_to_data() to map indices returned by neighbor finder to indices in data
  • Breaking up the method allows for logic, particularly that factored into map_knn_indices_to_data, to be reused during repeated nearest-neighbor queries without needing to go through the entire instantiation of the neighbor-finding algorithm.

Misc

  • These changes passes all the tests in pyEDM except for:
    • Specifying pdist as the algorithm will fail test_simplex10. However, the specific samples that cause the algorithm to differ from the expected predictions have values of <0, 0, 0>, which matches many lib entries that are also <0, 0, 0>. There are more of these degenerate neighbors than k, and the neighbors indices returned by the algorithms are instance-specific to the particular KDTree object or the outcomes of quicksort for np.argsort, so nearest-neighbor selection is not guaranteed to be deterministic in this case.
    • In test_ccm3, the comparison of two arrays separately rounded to 4 decimal points fails when np.allclose with an absolute tolerance of 1e-4 would consider them equal, likely because one value gets rounded up and the other down.

Copilot AI review requested due to automatic review settings January 27, 2026 18:55
Copy link

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Adds an optional pairwise-distance (“pdist”) neighbor-finding implementation alongside the existing KDTree approach, with wiring through the public API and EDM classes to enable reuse and speedups (notably for CCM).

Changes:

  • Introduces NeighborFinder abstractions for KDTree- and pairwise-distance-based kNN queries.
  • Refactors neighbor search into EDM.FindNeighbors() with shared helper methods for kNN/exclusion-radius handling and index mapping.
  • Updates CCM to support algorithm-specific neighbor-search logic (defaulting CCM to pdist) and relaxes one CCM test assertion to use tolerance-based comparison.

Reviewed changes

Copilot reviewed 8 out of 8 changed files in this pull request and generated 9 comments.

Show a summary per file
File Description
src/pyEDM/API.py Adds neighbor_algorithm parameter to public APIs and passes it into classes.
src/pyEDM/EDM.py Moves/refactors neighbor-finding into EDM; selects KDTree vs pdist via neighbor_algorithm.
src/pyEDM/NeighborFinder.py Adds KDTree and pairwise-distance neighbor finder implementations behind a common interface.
src/pyEDM/Simplex.py Plumbs neighbor_algorithm into EDM base; refactors target-value gathering for projection.
src/pyEDM/SMap.py Plumbs neighbor_algorithm into EDM base.
src/pyEDM/CCM.py Defaults CCM to pdist and adds distance-matrix reuse logic for repeated subsampling.
src/pyEDM/tests/tests.py Updates CCM nan test to use allclose(..., atol=1e-4) instead of rounding equality.
src/pyEDM/Neighbors.py Removes the legacy standalone KDTree neighbor-finding implementation.
Comments suppressed due to low confidence (1)

src/pyEDM/CCM.py:251

  • Same shape/broadcasting issue as in Simplex.Project(): S.targetVec[knn_neighbors_Tp].squeeze() can drop the neighbor axis when knn == 1, which can corrupt weights * libTargetValues and projections. Ensure libTargetValues retains (N_pred, k) shape (e.g., by selecting [:, 0] from targetVec rather than squeeze()).
                knn_neighbors_Tp = neighbor_indices + self.Tp      # Npred x k

                libTargetValues = S.targetVec[knn_neighbors_Tp].squeeze()
                # Code from Simplex:Project ----------------------------------

                # Projection is average of weighted knn library target values
                projection_ = sum( weights * libTargetValues,
                                   axis = 1) / weightRowSum

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

for j in range( knn_neighbors_Tp.shape[1] ) : # for each column j of k
libTargetValues[ :, j ][ :, None ] = \
self.targetVec[ knn_neighbors_Tp[ :, j ] ]
libTargetValues = self.targetVec[knn_neighbors_Tp].squeeze()
Copy link

Copilot AI Jan 27, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

libTargetValues = self.targetVec[knn_neighbors_Tp].squeeze() can collapse the neighbor dimension when knn == 1, producing shape (N_pred,) instead of (N_pred, 1) and causing incorrect broadcasting in weights * libTargetValues (e.g., (N,1)*(N,) -> (N,N)). Index the last dimension explicitly (e.g., select column 0) or otherwise ensure libTargetValues is always 2D with shape (N_pred, k).

Suggested change
libTargetValues = self.targetVec[knn_neighbors_Tp].squeeze()
libTargetValues = self.targetVec[knn_neighbors_Tp]
# Ensure libTargetValues remains 2D (N x k), even when k == 1
if libTargetValues.ndim == 1:
libTargetValues = libTargetValues.reshape(-1, 1)

Copilot uses AI. Check for mistakes.
d = S.neighbor_finder.distanceMatrix.copy()
mask = np.ones(d.shape[0], dtype = bool)
mask[rng_i] = False
d[mask, :] = np.inf # artificially make all the other ones far awa
Copy link

Copilot AI Jan 27, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Typo in comment: "far awa" should be "far away".

Suggested change
d[mask, :] = np.inf # artificially make all the other ones far awa
d[mask, :] = np.inf # artificially make all the other ones far away

Copilot uses AI. Check for mistakes.
Comment on lines +8 to +20
class NeighborFinderBase(AbstractBaseClass):
"""
Interface for describing a class to find nearest neighbors
"""

def __init__(self,
data: np.ndarray):
"""
Constructor
:param data: data in the shape of [samples, dimensions]
"""
self.data = data

Copy link

Copilot AI Jan 27, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This file uses tab characters for indentation. The rest of the codebase appears to use 4-space indentation (e.g., src/pyEDM/EDM.py), and tabs increase the risk of TabError if mixed with spaces and can break formatting/linting. Convert indentation to spaces throughout this module.

Copilot uses AI. Check for mistakes.

@staticmethod
def find_neighbors(distances: np.ndarray, k: int):
neighbors = np.argsort(distances, axis = 0)[:k, :]
Copy link

Copilot AI Jan 27, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

np.argsort(distances, axis=0) uses an unstable sort by default, so the ordering of equal-distance neighbors can vary across NumPy versions/platforms, making results non-reproducible (especially for degenerate/tied distances). Use a stable sort (e.g., kind="stable"/mergesort) and/or an explicit deterministic tie-breaker to keep pdist behavior reproducible.

Suggested change
neighbors = np.argsort(distances, axis = 0)[:k, :]
neighbors = np.argsort(distances, axis=0, kind="stable")[:k, :]

Copilot uses AI. Check for mistakes.
:param additional_distance:
:return:
"""
out = PairwiseDistanceNeighborFinder(None, None)
Copy link

Copilot AI Jan 27, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

update() constructs PairwiseDistanceNeighborFinder(None, None), leaving out.data as None. If the returned instance is later queried with a new x, distance.cdist(self.data, x, ...) will crash. Preserve the original data in the updated instance (or make query() validate self.data when x is provided).

Suggested change
out = PairwiseDistanceNeighborFinder(None, None)
out = PairwiseDistanceNeighborFinder(self.data, None)

Copilot uses AI. Check for mistakes.
Comment on lines 107 to +110
showPlot = False,
ignoreNan = True,
returnObject = False ):
returnObject = False,
neighbor_algorithm = 'kdtree'):
Copy link

Copilot AI Jan 27, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A new public neighbor_algorithm option is added for Simplex/SMap, but the test suite does not appear to exercise neighbor_algorithm='pdist' for these APIs. Add a unit/integration test that runs EDM.Simplex(..., neighbor_algorithm='pdist') and EDM.SMap(..., neighbor_algorithm='pdist') on a small non-degenerate dataset (no tied neighbors) to validate the new code path and prevent regressions.

Copilot uses AI. Check for mistakes.
src/pyEDM/EDM.py Outdated
compact_nodes = True,
balanced_tree = True)
elif self.neighbor_algorithm == 'pdist':
self.neighbor_finder = PairwiseDistanceNeighborFinder(self.Embedding.iloc[self.lib_i, :].to_numpy())
Copy link

Copilot AI Jan 27, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the 'pdist' branch, train is already computed (train = self.Embedding.iloc[self.lib_i, :].to_numpy()), but the code recomputes the same .to_numpy() again when constructing PairwiseDistanceNeighborFinder. Reuse train to avoid an extra conversion/allocation (this can be significant for large embeddings).

Suggested change
self.neighbor_finder = PairwiseDistanceNeighborFinder(self.Embedding.iloc[self.lib_i, :].to_numpy())
self.neighbor_finder = PairwiseDistanceNeighborFinder(train)

Copilot uses AI. Check for mistakes.
Comment on lines +21 to +24
def query(self,
x: np.ndarray,
k: int = 1,
**kwargs) -> Tuple[Union[float, np.ndarray], Union[int, np.ndarray]]:
Copy link

Copilot AI Jan 27, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overridden method signature does not match call, where it is passed too many arguments. Overriding method method KDTreeNeighborFinder.query matches the call.
Overridden method signature does not match call, where it is passed an argument named 'eps'. Overriding method method KDTreeNeighborFinder.query matches the call.
Overridden method signature does not match call, where it is passed an argument named 'p'. Overriding method method KDTreeNeighborFinder.query matches the call.
Overridden method signature does not match call, where it is passed an argument named 'workers'. Overriding method method KDTreeNeighborFinder.query matches the call.

Copilot uses AI. Check for mistakes.
@SoftwareLiteracy
Copy link
Collaborator

@candytaco : Thank you for this extensive contribution. I have been wanting to do this for a long time to speed up CCM ; ) I hope to be able to look closely at it in the next week, I would like to test it on some heavy CCM loads.

Much appreciated!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants