|
| 1 | +import numpy as np |
| 2 | +from autoconf import cached_property |
| 3 | + |
| 4 | +from autoarray.inversion.pixelization.mappers.abstract import AbstractMapper, PixSubWeights |
| 5 | + |
| 6 | +from autoarray.inversion.pixelization.mesh.knn import get_interpolation_weights |
| 7 | + |
| 8 | + |
| 9 | +class MapperKNNInterpolator(AbstractMapper): |
| 10 | + """ |
| 11 | + Mapper using kNN + compact Wendland kernel interpolation (partition of unity). |
| 12 | + """ |
| 13 | + |
| 14 | + # ---- You almost certainly want these as configurable attributes somewhere ---- |
| 15 | + # If your Mesh / Pixelization already stores these, read them from self.mesh instead. |
| 16 | + @property |
| 17 | + def k_neighbors(self) -> int: |
| 18 | + # e.g. return self.pixelization.k_neighbors or self.mesh.k_neighbors |
| 19 | + return getattr(self.pixelization, "k_neighbors", 10) |
| 20 | + |
| 21 | + @property |
| 22 | + def kernel(self) -> str: |
| 23 | + return getattr(self.pixelization, "kernel", "wendland_c4") |
| 24 | + |
| 25 | + @property |
| 26 | + def radius_scale(self) -> float: |
| 27 | + return getattr(self.pixelization, "radius_scale", 1.5) |
| 28 | + |
| 29 | + def _pix_sub_weights_from_query_points(self, query_points) -> PixSubWeights: |
| 30 | + """ |
| 31 | + Compute PixSubWeights for arbitrary query points using the kNN kernel module. |
| 32 | + Arrays are created in self._xp (numpy or jax.numpy) from the start. |
| 33 | + """ |
| 34 | + |
| 35 | + xp = self._xp # numpy or jax.numpy |
| 36 | + |
| 37 | + # ------------------------------------------------------------------ |
| 38 | + # Source nodes (pixelization "pixels") on the source-plane mesh grid |
| 39 | + # Shape: (N, 2) |
| 40 | + # ------------------------------------------------------------------ |
| 41 | + points = xp.asarray(self.source_plane_mesh_grid.array, dtype=xp.float64) |
| 42 | + |
| 43 | + # ------------------------------------------------------------------ |
| 44 | + # Query points (oversampled source-plane data grid or split points) |
| 45 | + # Shape: (M, 2) |
| 46 | + # ------------------------------------------------------------------ |
| 47 | + query_points = xp.asarray(query_points, dtype=xp.float64) |
| 48 | + |
| 49 | + # ------------------------------------------------------------------ |
| 50 | + # kNN kernel weights (runs in JAX, but accepts NumPy or JAX inputs) |
| 51 | + # Always returns JAX arrays |
| 52 | + # ------------------------------------------------------------------ |
| 53 | + weights_jax, indices_jax, _ = get_interpolation_weights( |
| 54 | + points=points, |
| 55 | + query_points=query_points, |
| 56 | + k_neighbors=int(self.k_neighbors), |
| 57 | + kernel=self.kernel, |
| 58 | + radius_scale=float(self.radius_scale), |
| 59 | + ) |
| 60 | + |
| 61 | + # ------------------------------------------------------------------ |
| 62 | + # Convert outputs to xp backend *only if needed* |
| 63 | + # ------------------------------------------------------------------ |
| 64 | + if xp is jnp: |
| 65 | + weights = weights_jax |
| 66 | + mappings = indices_jax |
| 67 | + else: |
| 68 | + # xp is numpy |
| 69 | + weights = np.asarray(weights_jax) |
| 70 | + mappings = np.asarray(indices_jax) |
| 71 | + |
| 72 | + # ------------------------------------------------------------------ |
| 73 | + # Sizes: always k for kNN |
| 74 | + # Shape: (M,) |
| 75 | + # ------------------------------------------------------------------ |
| 76 | + sizes = xp.full( |
| 77 | + (mappings.shape[0],), |
| 78 | + mappings.shape[1], |
| 79 | + dtype=xp.int32, |
| 80 | + ) |
| 81 | + |
| 82 | + # Ensure correct dtypes |
| 83 | + mappings = mappings.astype(xp.int32) |
| 84 | + weights = weights.astype(xp.float64) |
| 85 | + |
| 86 | + return PixSubWeights( |
| 87 | + mappings=mappings, |
| 88 | + sizes=sizes, |
| 89 | + weights=weights, |
| 90 | + ) |
| 91 | + |
| 92 | + @cached_property |
| 93 | + def pix_sub_weights(self) -> PixSubWeights: |
| 94 | + """ |
| 95 | + kNN mappings + kernel weights for every oversampled source-plane data-grid point. |
| 96 | + """ |
| 97 | + return self._pix_sub_weights_from_query_points( |
| 98 | + query_points=self.source_plane_data_grid.over_sampled |
| 99 | + ) |
| 100 | + |
| 101 | + @property |
| 102 | + def pix_sub_weights_split_points(self) -> PixSubWeights: |
| 103 | + """ |
| 104 | + kNN mappings + kernel weights computed at split points (for split regularization schemes). |
| 105 | + """ |
| 106 | + # Your Delaunay mesh exposes split points via self.delaunay.split_points. |
| 107 | + # For KNN mesh, you should expose the same property. If not, route appropriately: |
| 108 | + # split_points = self.mesh.split_points |
| 109 | + split_points = self.delaunay.split_points # keep consistent with existing API |
| 110 | + |
| 111 | + return self._pix_sub_weights_from_query_points(query_points=split_points) |
0 commit comments