Skip to content

Commit def14ce

Browse files
Jammy2211Jammy2211
authored andcommitted
added all functionality and KNNInterpolator class
1 parent b2dd830 commit def14ce

1 file changed

Lines changed: 265 additions & 0 deletions

File tree

  • autoarray/inversion/pixelization/mesh
Lines changed: 265 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,265 @@
1+
"""
2+
Optimized Kernel-Based Interpolation in JAX
3+
Uses Wendland compactly supported kernels with normalized weights (partition of unity).
4+
More robust and faster than MLS, better accuracy than simple IDW.
5+
"""
6+
import jax
7+
import jax.numpy as jnp
8+
from functools import partial
9+
10+
11+
def get_interpolation_weights(points, query_points, k_neighbors=10, kernel='wendland_c4',
12+
radius_scale=1.5):
13+
"""
14+
Compute interpolation weights between source points and query points.
15+
16+
This is a standalone function to get the weights used in kernel interpolation,
17+
useful when you want to analyze or reuse weights separately from interpolation.
18+
19+
Args:
20+
points: (N, 2) source point coordinates
21+
query_points: (M, 2) query point coordinates
22+
k_neighbors: number of nearest neighbors (default: 10)
23+
kernel: 'wendland_c2', 'wendland_c4', or 'wendland_c6' (default: 'wendland_c4')
24+
radius_scale: multiplier for auto-computed radius (default: 1.5)
25+
26+
Returns:
27+
weights: (M, k) normalized weights for each query point
28+
indices: (M, k) indices of K nearest neighbors in points array
29+
distances: (M, k) distances to K nearest neighbors
30+
31+
Example:
32+
>>> weights, indices, distances = get_interpolation_weights(src_pts, query_pts)
33+
>>> # Now you can use weights and indices for custom interpolation
34+
>>> interpolated = jnp.sum(weights * values[indices], axis=1)
35+
"""
36+
points = jnp.asarray(points)
37+
query_points = jnp.asarray(query_points)
38+
39+
# Select kernel function
40+
if kernel == 'wendland_c2':
41+
kernel_fn = wendland_c2
42+
elif kernel == 'wendland_c4':
43+
kernel_fn = wendland_c4
44+
elif kernel == 'wendland_c6':
45+
kernel_fn = wendland_c6
46+
else:
47+
raise ValueError(f"Unknown kernel: {kernel}")
48+
49+
return compute_weights(points, query_points, k_neighbors, radius_scale, kernel_fn)
50+
51+
52+
def kernel_interpolate(points, values, query_points, k_neighbors=10, kernel='wendland_c4',
53+
radius_scale=1.5, chunk_size=None):
54+
"""
55+
Kernel-based interpolation using K-nearest neighbors with Wendland kernels.
56+
57+
Uses normalized kernel weights ensuring partition of unity for better accuracy.
58+
More robust than MLS (no linear solve) and more accurate than simple 1/d^p.
59+
60+
Args:
61+
points: (N, 2) source point coordinates
62+
values: (N,) values at source points
63+
query_points: (M, 2) query point coordinates
64+
k_neighbors: number of nearest neighbors (default: 10)
65+
kernel: 'wendland_c2', 'wendland_c4', or 'wendland_c6' (default: 'wendland_c4')
66+
radius_scale: multiplier for auto-computed radius (default: 1.5)
67+
chunk_size: if provided, process queries in chunks
68+
69+
Returns:
70+
(M,) interpolated values
71+
"""
72+
points = jnp.asarray(points)
73+
values = jnp.asarray(values)
74+
query_points = jnp.asarray(query_points)
75+
76+
# Select kernel function
77+
if kernel == 'wendland_c2':
78+
kernel_fn = wendland_c2
79+
elif kernel == 'wendland_c4':
80+
kernel_fn = wendland_c4
81+
elif kernel == 'wendland_c6':
82+
kernel_fn = wendland_c6
83+
else:
84+
raise ValueError(f"Unknown kernel: {kernel}")
85+
86+
if chunk_size is None:
87+
return _kernel_knn_jit(points, values, query_points, k_neighbors,
88+
radius_scale, kernel_fn)
89+
else:
90+
return _kernel_chunked(points, values, query_points, k_neighbors,
91+
radius_scale, kernel_fn, int(chunk_size))
92+
93+
94+
def wendland_c2(r, h):
95+
"""
96+
Wendland C2: (1 - r/h)^4 * (4*r/h + 1)
97+
C2 continuous, compact support
98+
"""
99+
s = r / (h + 1e-10)
100+
w = jnp.where(s < 1.0, (1.0 - s) ** 4 * (4.0 * s + 1.0), 0.0)
101+
return w
102+
103+
104+
def wendland_c4(r, h):
105+
"""
106+
Wendland C4: (1 - r/h)^6 * (35*(r/h)^2 + 18*r/h + 3)
107+
C4 continuous, smoother, compact support
108+
"""
109+
s = r / (h + 1e-10)
110+
w = jnp.where(s < 1.0, (1.0 - s) ** 6 * (35.0 * s ** 2 + 18.0 * s + 3.0), 0.0)
111+
return w
112+
113+
114+
def wendland_c6(r, h):
115+
"""
116+
Wendland C6: (1 - r/h)^8 * (32*(r/h)^3 + 25*(r/h)^2 + 8*r/h + 1)
117+
C6 continuous, very smooth, compact support
118+
"""
119+
s = r / (h + 1e-10)
120+
w = jnp.where(s < 1.0, (1.0 - s) ** 8 * (32.0 * s ** 3 + 25.0 * s ** 2 + 8.0 * s + 1.0), 0.0)
121+
return w
122+
123+
124+
def compute_weights(points, query_points, k_neighbors, radius_scale, kernel_fn):
125+
"""
126+
Compute normalized kernel weights for interpolation.
127+
128+
This function computes the weights between source points and
129+
query points using K-nearest neighbors and Wendland kernels.
130+
131+
Args:
132+
points: (N, 2) source point coordinates
133+
query_points: (M, 2) query point coordinates
134+
k_neighbors: number of nearest neighbors
135+
radius_scale: multiplier for auto-computed radius
136+
kernel_fn: kernel function (wendland_c2/c4/c6)
137+
138+
Returns:
139+
weights: (M, k) normalized weights for each query point
140+
indices: (M, k) indices of K nearest neighbors for each query point
141+
distances: (M, k) distances to K nearest neighbors
142+
"""
143+
# Compute pairwise distances
144+
diff = query_points[:, None, :] - points[None, :, :] # (M, N, 2)
145+
dist_sq = jnp.sum(diff * diff, axis=-1) # (M, N)
146+
dist = jnp.sqrt(dist_sq) # (M, N)
147+
148+
# Find K nearest neighbors
149+
top_k_vals, top_k_indices = jax.lax.top_k(-dist, k_neighbors) # negative for smallest
150+
knn_distances = -top_k_vals # (M, k)
151+
152+
# Auto-compute radius: use max KNN distance + margin
153+
h = jnp.max(knn_distances, axis=1, keepdims=True) * radius_scale # (M, 1)
154+
155+
# Compute kernel weights
156+
weights = kernel_fn(knn_distances, h) # (M, k)
157+
158+
# Normalize weights (partition of unity)
159+
# Add small epsilon to avoid division by zero
160+
weight_sum = jnp.sum(weights, axis=1, keepdims=True) + 1e-10 # (M, 1)
161+
weights_normalized = weights / weight_sum # (M, k)
162+
163+
return weights_normalized, top_k_indices, knn_distances
164+
165+
166+
def _compute_kernel_knn(query_chunk, points, values, k, radius_scale, kernel_fn):
167+
"""
168+
Compute kernel interpolation for a chunk of query points using K nearest neighbors.
169+
170+
Args:
171+
query_chunk: (M, 2) query points
172+
points: (N, 2) source points
173+
values: (N,) values at source points
174+
k: number of nearest neighbors
175+
radius_scale: multiplier for radius
176+
kernel_fn: kernel function
177+
178+
Returns:
179+
(M,) interpolated values
180+
"""
181+
# Compute weights using the intermediate function
182+
weights_normalized, top_k_indices, _ = compute_weights(
183+
points, query_chunk, k, radius_scale, kernel_fn
184+
)
185+
186+
# Get neighbor values
187+
neighbor_values = values[top_k_indices] # (M, k)
188+
189+
# Interpolate: weighted sum
190+
interpolated = jnp.sum(weights_normalized * neighbor_values, axis=1) # (M,)
191+
192+
return interpolated
193+
194+
195+
@partial(jax.jit, static_argnames=("k_neighbors", "kernel_fn"))
196+
def _kernel_knn_jit(points, values, query_points, k_neighbors, radius_scale, kernel_fn):
197+
"""
198+
JIT-compiled kernel interpolation.
199+
"""
200+
return _compute_kernel_knn(query_points, points, values, k_neighbors,
201+
radius_scale, kernel_fn)
202+
203+
204+
def _kernel_chunked(points, values, query_points, k_neighbors, radius_scale,
205+
kernel_fn, chunk_size):
206+
"""
207+
Chunked kernel interpolation for memory efficiency.
208+
"""
209+
M = query_points.shape[0]
210+
D = query_points.shape[1]
211+
212+
# Pad queries
213+
remainder = M % chunk_size
214+
pad = 0 if remainder == 0 else (chunk_size - remainder)
215+
if pad:
216+
qp_pad = jnp.pad(query_points, ((0, pad), (0, 0)))
217+
else:
218+
qp_pad = query_points
219+
220+
out_pad = _kernel_chunked_jit(points, values, qp_pad, k_neighbors,
221+
radius_scale, kernel_fn, chunk_size)
222+
return out_pad[:M]
223+
224+
225+
@partial(jax.jit, static_argnames=("k_neighbors", "kernel_fn", "chunk_size"))
226+
def _kernel_chunked_jit(points, values, query_points_padded, k_neighbors,
227+
radius_scale, kernel_fn, chunk_size):
228+
"""
229+
JIT-compiled chunked kernel interpolation.
230+
"""
231+
M_pad = query_points_padded.shape[0]
232+
D = points.shape[1]
233+
n_chunks = M_pad // chunk_size
234+
235+
out = jnp.zeros((M_pad,), dtype=values.dtype)
236+
237+
def body_fun(i, out_acc):
238+
start = i * chunk_size
239+
240+
# Extract chunk
241+
q_chunk = jax.lax.dynamic_slice(
242+
query_points_padded, (start, 0), (chunk_size, D)
243+
)
244+
245+
# Compute kernel interpolation for this chunk
246+
result_chunk = _compute_kernel_knn(q_chunk, points, values, k_neighbors,
247+
radius_scale, kernel_fn)
248+
249+
# Update output
250+
out_acc = jax.lax.dynamic_update_slice(out_acc, result_chunk, (start,))
251+
252+
return out_acc
253+
254+
out = jax.lax.fori_loop(0, n_chunks, body_fun, out)
255+
return out
256+
257+
258+
from autoarray.inversion.pixelization.mesh.delaunay import Delaunay
259+
260+
class KNNInterpolator(Delaunay):
261+
262+
def __init__(self):
263+
264+
super().__init__()
265+

0 commit comments

Comments
 (0)