Add pairwise distance-based neighbor finding in addition to KDTree #69
base: master
…nce-based neighbor finding
…omparisons of two rounded arrays
Pull request overview
Adds an optional pairwise-distance (“pdist”) neighbor-finding implementation alongside the existing KDTree approach, with wiring through the public API and EDM classes to enable reuse and speedups (notably for CCM).
Changes:
- Introduces `NeighborFinder` abstractions for KDTree- and pairwise-distance-based kNN queries.
- Refactors neighbor search into `EDM.FindNeighbors()` with shared helper methods for kNN/exclusion-radius handling and index mapping.
- Updates CCM to support algorithm-specific neighbor-search logic (defaulting CCM to `pdist`) and relaxes one CCM test assertion to use tolerance-based comparison.
Reviewed changes
Copilot reviewed 8 out of 8 changed files in this pull request and generated 9 comments.
| File | Description |
|---|---|
| src/pyEDM/API.py | Adds neighbor_algorithm parameter to public APIs and passes it into classes. |
| src/pyEDM/EDM.py | Moves/refactors neighbor-finding into EDM; selects KDTree vs pdist via neighbor_algorithm. |
| src/pyEDM/NeighborFinder.py | Adds KDTree and pairwise-distance neighbor finder implementations behind a common interface. |
| src/pyEDM/Simplex.py | Plumbs neighbor_algorithm into EDM base; refactors target-value gathering for projection. |
| src/pyEDM/SMap.py | Plumbs neighbor_algorithm into EDM base. |
| src/pyEDM/CCM.py | Defaults CCM to pdist and adds distance-matrix reuse logic for repeated subsampling. |
| src/pyEDM/tests/tests.py | Updates CCM nan test to use allclose(..., atol=1e-4) instead of rounding equality. |
| src/pyEDM/Neighbors.py | Removes the legacy standalone KDTree neighbor-finding implementation. |
Comments suppressed due to low confidence (1)
src/pyEDM/CCM.py:251
- Same shape/broadcasting issue as in `Simplex.Project()`: `S.targetVec[knn_neighbors_Tp].squeeze()` can drop the neighbor axis when `knn == 1`, which can corrupt `weights * libTargetValues` and projections. Ensure `libTargetValues` retains `(N_pred, k)` shape (e.g., by selecting `[:, 0]` from `targetVec` rather than `squeeze()`).
knn_neighbors_Tp = neighbor_indices + self.Tp # Npred x k
libTargetValues = S.targetVec[knn_neighbors_Tp].squeeze()
# Code from Simplex:Project ----------------------------------
# Projection is average of weighted knn library target values
projection_ = sum( weights * libTargetValues,
axis = 1) / weightRowSum
| for j in range( knn_neighbors_Tp.shape[1] ) : # for each column j of k | ||
| libTargetValues[ :, j ][ :, None ] = \ | ||
| self.targetVec[ knn_neighbors_Tp[ :, j ] ] | ||
| libTargetValues = self.targetVec[knn_neighbors_Tp].squeeze() |
Copilot
AI
Jan 27, 2026
libTargetValues = self.targetVec[knn_neighbors_Tp].squeeze() can collapse the neighbor dimension when knn == 1, producing shape (N_pred,) instead of (N_pred, 1) and causing incorrect broadcasting in weights * libTargetValues (e.g., (N,1)*(N,) -> (N,N)). Index the last dimension explicitly (e.g., select column 0) or otherwise ensure libTargetValues is always 2D with shape (N_pred, k).
| libTargetValues = self.targetVec[knn_neighbors_Tp].squeeze() | |
| libTargetValues = self.targetVec[knn_neighbors_Tp] | |
| # Ensure libTargetValues remains 2D (N x k), even when k == 1 | |
| if libTargetValues.ndim == 1: | |
| libTargetValues = libTargetValues.reshape(-1, 1) |
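The broadcasting hazard described in this comment can be reproduced with plain NumPy. The sketch below uses hypothetical stand-in arrays (`target_vec`, `neighbor_idx`, `weights`) and assumes the target vector is stored as a column of shape `(M, 1)`; the actual shapes in pyEDM may differ.

```python
import numpy as np

# Hypothetical stand-ins for the arrays named in the review comment.
target_vec = np.arange(10.0).reshape(-1, 1)    # assumed column vector, shape (10, 1)
neighbor_idx = np.array([[2], [5], [7]])       # N_pred = 3 rows, k = 1 neighbor each
weights = np.ones((3, 1))                      # shape (N_pred, k)

# squeeze() removes every length-1 axis: (3, 1, 1) collapses to (3,)
squeezed = target_vec[neighbor_idx].squeeze()
bad = weights * squeezed                       # (3, 1) * (3,) broadcasts to (3, 3)

# Dropping only the trailing axis keeps the neighbor axis: (3, 1, 1) -> (3, 1)
kept = target_vec[neighbor_idx][:, :, 0]
good = weights * kept                          # stays (3, 1)

print(squeezed.shape, bad.shape, kept.shape, good.shape)
```

Indexing the trailing axis explicitly (or calling `squeeze(axis=-1)`) removes only the axis that is known to be redundant, so the result has the same rank for `k == 1` and `k > 1`.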
| d = S.neighbor_finder.distanceMatrix.copy() | ||
| mask = np.ones(d.shape[0], dtype = bool) | ||
| mask[rng_i] = False | ||
| d[mask, :] = np.inf # artificially make all the other ones far awa |
Copilot
AI
Jan 27, 2026
Typo in comment: "far awa" should be "far away".
| d[mask, :] = np.inf # artificially make all the other ones far awa | |
| d[mask, :] = np.inf # artificially make all the other ones far away |
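For context, the masking idea in the hunk above can be sketched in isolation. This is a minimal illustration, not the PR's implementation: `knn_from_masked_distances` and its arguments are hypothetical names, and the library-by-query layout of the distance matrix is an assumption taken from the `axis = 0` sort used elsewhere in the diff.

```python
import numpy as np

def knn_from_masked_distances(distance_matrix, sampled_rows, k):
    """Hypothetical helper: k nearest library rows for each query column,
    restricted to the library rows listed in sampled_rows."""
    d = distance_matrix.copy()                 # leave the cached matrix intact
    mask = np.ones(d.shape[0], dtype=bool)
    mask[sampled_rows] = False                 # False marks the rows to keep
    d[mask, :] = np.inf                        # non-sampled rows can never be nearest
    order = np.argsort(d, axis=0, kind="stable")[:k, :]
    # Return (N_query, k) indices and the matching distances.
    return order.T, np.take_along_axis(d, order, axis=0).T

rng = np.random.default_rng(0)
library = rng.normal(size=(8, 2))
queries = rng.normal(size=(3, 2))
# Full library-by-query Euclidean distance matrix, computed once and reused.
full = np.linalg.norm(library[:, None, :] - queries[None, :, :], axis=2)

idx, dist = knn_from_masked_distances(full, sampled_rows=[1, 4, 6], k=2)
print(idx)
```

Because the full matrix is computed once, each subsample costs only a copy, a mask assignment, and a sort, rather than a fresh neighbor-structure build.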
| class NeighborFinderBase(AbstractBaseClass): | ||
| """ | ||
| Interface for describing a class to find nearest neighbors | ||
| """ | ||
|
|
||
| def __init__(self, | ||
| data: np.ndarray): | ||
| """ | ||
| Constructor | ||
| :param data: data in the shape of [samples, dimensions] | ||
| """ | ||
| self.data = data | ||
|
|
Copilot
AI
Jan 27, 2026
This file uses tab characters for indentation. The rest of the codebase appears to use 4-space indentation (e.g., src/pyEDM/EDM.py), and tabs increase the risk of TabError if mixed with spaces and can break formatting/linting. Convert indentation to spaces throughout this module.
|
|
||
| @staticmethod | ||
| def find_neighbors(distances: np.ndarray, k: int): | ||
| neighbors = np.argsort(distances, axis = 0)[:k, :] |
Copilot
AI
Jan 27, 2026
np.argsort(distances, axis=0) uses an unstable sort by default, so the ordering of equal-distance neighbors can vary across NumPy versions/platforms, making results non-reproducible (especially for degenerate/tied distances). Use a stable sort (e.g., kind="stable"/mergesort) and/or an explicit deterministic tie-breaker to keep pdist behavior reproducible.
| neighbors = np.argsort(distances, axis = 0)[:k, :] | |
| neighbors = np.argsort(distances, axis=0, kind="stable")[:k, :] |
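A small example of the suggested change, using a made-up distance matrix with deliberate ties (rows are library points, columns are query points, following the layout assumed from the hunk above). With `kind="stable"`, tied distances keep their original relative order, so the lowest library index wins.

```python
import numpy as np

# Made-up distances with ties in both columns.
distances = np.array([[0.0, 1.0],
                      [0.0, 0.5],
                      [0.0, 1.0],
                      [2.0, 0.5]])
k = 2

# Stable sort: among equal distances, the smaller row index comes first.
neighbors = np.argsort(distances, axis=0, kind="stable")[:k, :]
print(neighbors)
```

For the first query, three library points are tied at distance 0 and the two lowest indices (0 and 1) are selected; for the second, the two points at distance 0.5 (indices 1 and 3) are selected.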
| :param additional_distance: | ||
| :return: | ||
| """ | ||
| out = PairwiseDistanceNeighborFinder(None, None) |
Copilot
AI
Jan 27, 2026
update() constructs PairwiseDistanceNeighborFinder(None, None), leaving out.data as None. If the returned instance is later queried with a new x, distance.cdist(self.data, x, ...) will crash. Preserve the original data in the updated instance (or make query() validate self.data when x is provided).
| out = PairwiseDistanceNeighborFinder(None, None) | |
| out = PairwiseDistanceNeighborFinder(self.data, None) |
| showPlot = False, | ||
| ignoreNan = True, | ||
| returnObject = False ): | ||
| returnObject = False, | ||
| neighbor_algorithm = 'kdtree'): |
Copilot
AI
Jan 27, 2026
A new public neighbor_algorithm option is added for Simplex/SMap, but the test suite does not appear to exercise neighbor_algorithm='pdist' for these APIs. Add a unit/integration test that runs EDM.Simplex(..., neighbor_algorithm='pdist') and EDM.SMap(..., neighbor_algorithm='pdist') on a small non-degenerate dataset (no tied neighbors) to validate the new code path and prevent regressions.
src/pyEDM/EDM.py
Outdated
| compact_nodes = True, | ||
| balanced_tree = True) | ||
| elif self.neighbor_algorithm == 'pdist': | ||
| self.neighbor_finder = PairwiseDistanceNeighborFinder(self.Embedding.iloc[self.lib_i, :].to_numpy()) |
Copilot
AI
Jan 27, 2026
In the 'pdist' branch, train is already computed (train = self.Embedding.iloc[self.lib_i, :].to_numpy()), but the code recomputes the same .to_numpy() again when constructing PairwiseDistanceNeighborFinder. Reuse train to avoid an extra conversion/allocation (this can be significant for large embeddings).
| self.neighbor_finder = PairwiseDistanceNeighborFinder(self.Embedding.iloc[self.lib_i, :].to_numpy()) | |
| self.neighbor_finder = PairwiseDistanceNeighborFinder(train) |
| def query(self, | ||
| x: np.ndarray, | ||
| k: int = 1, | ||
| **kwargs) -> Tuple[Union[float, np.ndarray], Union[int, np.ndarray]]: |
Copilot
AI
Jan 27, 2026
Overridden method signature does not match call, where it is passed too many arguments. Overriding method KDTreeNeighborFinder.query matches the call.
Overridden method signature does not match call, where it is passed an argument named 'eps'. Overriding method KDTreeNeighborFinder.query matches the call.
Overridden method signature does not match call, where it is passed an argument named 'p'. Overriding method KDTreeNeighborFinder.query matches the call.
Overridden method signature does not match call, where it is passed an argument named 'workers'. Overriding method KDTreeNeighborFinder.query matches the call.
…ex-mapping logic and neighbor-exclusion logic
…y take a column out of the middle of the neighbors instead of the furthest away neighbor
@candytaco : Thank you for this extensive contribution. I have been wanting to do this for a long time to speed up CCM ; ) I hope to be able to look closely at it in the next week, I would like to test it on some heavy CCM loads. Much appreciated!
This pull request adds a pairwise distance matrix-based neighbor-finding algorithm for EDM, and the option to switch between KDTree and pairwise distance matrices. The pairwise distance matrices allow for reuse and significantly improve runtimes during repeated queries, e.g. CCM.
API additions for neighbor search:
- `neighbor_algorithm` parameter added to the main API functions, allowing users to select between different neighbor-finding algorithms (currently `'kdtree'` and `'pdist'`).
- The `neighbor_algorithm` parameter is passed to the EDM classes, and the EDM base class uses it to switch between algorithms.
- The default `neighbor_algorithm` is `kdtree` to keep internal calculations consistent with previous versions, except in CCM, where it is `pdist` for performance speedups.

Performance improvements:
- For `kdtree`, the logic is kept the same.
- For `pdist`, on each sample of each libSize, the distance matrix is copied and only the sampled library points are kept. This is done by setting the pairwise distances for all other points to `np.inf`. This modified distance matrix is then re-queried for nearest neighbors.
- The `pdist` implementation offers a … `pdist` at least matches KDTree performance (Xeon E3-1245v2) and can provide a ~6x speedup (2x Xeon 5318Y).

Codebase changes:
- New classes in `NeighborFinder.py` to be used by EDM classes. They provide a consistent API for both KDTree and pairwise distance-based neighbor finding. The `NeighborFinderBase` class specifies an interface for adding additional future algorithms. The `KDTreeNeighborFinder` and `PairwiseDistanceNeighborFinder` classes inherit from the base class.
- A `neighbor_finder` field for the EDM base class: an object inheriting from `NeighborFinderBase` that implements the actual neighbor-finding algorithm.
- The `FindNeighbors` method is moved into the EDM class file, and some of its internal logic is factored out into separate methods and properties:
  - `EDM.check_lib_valid()` for checking that lib indices are valid
  - `EDM.excludionRadius_knn` to check if exclusion radius adjustment is needed
  - `EDM.knn_` for getting the actual number of nearest neighbors to query
  - `EDM.map_knn_indices_to_data()` to map indices returned by the neighbor finder to indices in data
- … `map_knn_indices_to_data`, to be reused during repeated nearest-neighbor queries without needing to go through the entire instantiation of the neighbor-finding algorithm.

Misc
- Using `pdist` as the algorithm will fail `test_simplex10`. However, the specific samples that cause the algorithm to differ from the expected predictions have values of <0, 0, 0>, which matches many lib entries that are also <0, 0, 0>. There are more of these degenerate neighbors than `k`, and the neighbor indices returned by the algorithms are instance-specific to the particular KDTree object or the outcomes of quicksort for `np.argsort`, so nearest-neighbor selection is not guaranteed to be deterministic in this case.
- In `test_ccm3`, the comparison of two arrays separately rounded to 4 decimal places fails when `np.allclose` with an absolute tolerance of 1e-4 would consider them equal, likely because one value gets rounded up and the other down.
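
The rounding mismatch described for `test_ccm3` can be reproduced with two made-up values that straddle a rounding boundary. The numbers below are illustrative only and are not taken from the test suite.

```python
import numpy as np

# Two hypothetical results that differ by 2e-5 but fall on opposite
# sides of the 4-decimal rounding boundary at 0.12345.
a = np.array([0.12344])
b = np.array([0.12346])

# Rounding each array separately gives 0.1234 vs 0.1235, so equality fails.
rounded_equal = np.array_equal(np.round(a, 4), np.round(b, 4))

# A purely absolute tolerance of 1e-4 treats them as equal.
close = np.allclose(a, b, rtol=0.0, atol=1e-4)

print(rounded_equal, close)
```

Here `rounded_equal` is `False` while `close` is `True`, which is why a tolerance-based comparison is the more robust assertion for floating-point results.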