Skip to content

Commit 6477976

Browse files
Jammy2211Jammy2211
authored andcommitted
added MapperKNNInterpolator
1 parent def14ce commit 6477976

2 files changed

Lines changed: 112 additions & 1 deletion

File tree

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
import numpy as np
2+
from autoconf import cached_property
3+
4+
from autoarray.inversion.pixelization.mappers.abstract import AbstractMapper, PixSubWeights
5+
6+
from autoarray.inversion.pixelization.mesh.knn import get_interpolation_weights
7+
8+
9+
class MapperKNNInterpolator(AbstractMapper):
10+
"""
11+
Mapper using kNN + compact Wendland kernel interpolation (partition of unity).
12+
"""
13+
14+
# ---- You almost certainly want these as configurable attributes somewhere ----
15+
# If your Mesh / Pixelization already stores these, read them from self.mesh instead.
16+
@property
17+
def k_neighbors(self) -> int:
18+
# e.g. return self.pixelization.k_neighbors or self.mesh.k_neighbors
19+
return getattr(self.pixelization, "k_neighbors", 10)
20+
21+
@property
22+
def kernel(self) -> str:
23+
return getattr(self.pixelization, "kernel", "wendland_c4")
24+
25+
@property
26+
def radius_scale(self) -> float:
27+
return getattr(self.pixelization, "radius_scale", 1.5)
28+
29+
def _pix_sub_weights_from_query_points(self, query_points) -> PixSubWeights:
30+
"""
31+
Compute PixSubWeights for arbitrary query points using the kNN kernel module.
32+
Arrays are created in self._xp (numpy or jax.numpy) from the start.
33+
"""
34+
35+
xp = self._xp # numpy or jax.numpy
36+
37+
# ------------------------------------------------------------------
38+
# Source nodes (pixelization "pixels") on the source-plane mesh grid
39+
# Shape: (N, 2)
40+
# ------------------------------------------------------------------
41+
points = xp.asarray(self.source_plane_mesh_grid.array, dtype=xp.float64)
42+
43+
# ------------------------------------------------------------------
44+
# Query points (oversampled source-plane data grid or split points)
45+
# Shape: (M, 2)
46+
# ------------------------------------------------------------------
47+
query_points = xp.asarray(query_points, dtype=xp.float64)
48+
49+
# ------------------------------------------------------------------
50+
# kNN kernel weights (runs in JAX, but accepts NumPy or JAX inputs)
51+
# Always returns JAX arrays
52+
# ------------------------------------------------------------------
53+
weights_jax, indices_jax, _ = get_interpolation_weights(
54+
points=points,
55+
query_points=query_points,
56+
k_neighbors=int(self.k_neighbors),
57+
kernel=self.kernel,
58+
radius_scale=float(self.radius_scale),
59+
)
60+
61+
# ------------------------------------------------------------------
62+
# Convert outputs to xp backend *only if needed*
63+
# ------------------------------------------------------------------
64+
if xp is jnp:
65+
weights = weights_jax
66+
mappings = indices_jax
67+
else:
68+
# xp is numpy
69+
weights = np.asarray(weights_jax)
70+
mappings = np.asarray(indices_jax)
71+
72+
# ------------------------------------------------------------------
73+
# Sizes: always k for kNN
74+
# Shape: (M,)
75+
# ------------------------------------------------------------------
76+
sizes = xp.full(
77+
(mappings.shape[0],),
78+
mappings.shape[1],
79+
dtype=xp.int32,
80+
)
81+
82+
# Ensure correct dtypes
83+
mappings = mappings.astype(xp.int32)
84+
weights = weights.astype(xp.float64)
85+
86+
return PixSubWeights(
87+
mappings=mappings,
88+
sizes=sizes,
89+
weights=weights,
90+
)
91+
92+
@cached_property
93+
def pix_sub_weights(self) -> PixSubWeights:
94+
"""
95+
kNN mappings + kernel weights for every oversampled source-plane data-grid point.
96+
"""
97+
return self._pix_sub_weights_from_query_points(
98+
query_points=self.source_plane_data_grid.over_sampled
99+
)
100+
101+
@property
102+
def pix_sub_weights_split_points(self) -> PixSubWeights:
103+
"""
104+
kNN mappings + kernel weights computed at split points (for split regularization schemes).
105+
"""
106+
# Your Delaunay mesh exposes split points via self.delaunay.split_points.
107+
# For KNN mesh, you should expose the same property. If not, route appropriately:
108+
# split_points = self.mesh.split_points
109+
split_points = self.delaunay.split_points # keep consistent with existing API
110+
111+
return self._pix_sub_weights_from_query_points(query_points=split_points)

autoarray/inversion/pixelization/mesh/knn.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -262,4 +262,4 @@ class KNNInterpolator(Delaunay):
262262
def __init__(self):
263263

264264
super().__init__()
265-
265+

0 commit comments

Comments
 (0)