Skip to content

Commit 10c2ec4

Browse files
Jammy2211Jammy2211
authored andcommitted
mapper now comes from factory as required
1 parent 6477976 commit 10c2ec4

4 files changed

Lines changed: 58 additions & 19 deletions

File tree

autoarray/inversion/pixelization/mappers/factory.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from autoarray.structures.mesh.rectangular_2d import Mesh2DRectangular
99
from autoarray.structures.mesh.rectangular_2d_uniform import Mesh2DRectangularUniform
1010
from autoarray.structures.mesh.delaunay_2d import Mesh2DDelaunay
11+
from autoarray.structures.mesh.knn_delaunay_2d import Mesh2DDelaunayKNN
1112

1213

1314
def mapper_from(
@@ -49,6 +50,7 @@ def mapper_from(
4950
MapperRectangularUniform,
5051
)
5152
from autoarray.inversion.pixelization.mappers.delaunay import MapperDelaunay
53+
from autoarray.inversion.pixelization.mappers.knn import MapperKNNInterpolator
5254

5355
if isinstance(mapper_grids.source_plane_mesh_grid, Mesh2DRectangularUniform):
5456
return MapperRectangularUniform(
@@ -77,3 +79,12 @@ def mapper_from(
7779
preloads=preloads,
7880
xp=xp,
7981
)
82+
elif isinstance(mapper_grids.source_plane_mesh_grid, Mesh2DDelaunayKNN):
83+
return MapperKNNInterpolator(
84+
mapper_grids=mapper_grids,
85+
border_relocator=border_relocator,
86+
regularization=regularization,
87+
settings=settings,
88+
preloads=preloads,
89+
xp=xp,
90+
)

autoarray/inversion/pixelization/mappers/knn.py

Lines changed: 7 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -11,27 +11,16 @@ class MapperKNNInterpolator(AbstractMapper):
1111
Mapper using kNN + compact Wendland kernel interpolation (partition of unity).
1212
"""
1313

14-
# ---- You almost certainly want these as configurable attributes somewhere ----
15-
# If your Mesh / Pixelization already stores these, read them from self.mesh instead.
16-
@property
17-
def k_neighbors(self) -> int:
18-
# e.g. return self.pixelization.k_neighbors or self.mesh.k_neighbors
19-
return getattr(self.pixelization, "k_neighbors", 10)
20-
21-
@property
22-
def kernel(self) -> str:
23-
return getattr(self.pixelization, "kernel", "wendland_c4")
24-
25-
@property
26-
def radius_scale(self) -> float:
27-
return getattr(self.pixelization, "radius_scale", 1.5)
28-
2914
def _pix_sub_weights_from_query_points(self, query_points) -> PixSubWeights:
3015
"""
3116
Compute PixSubWeights for arbitrary query points using the kNN kernel module.
3217
Arrays are created in self._xp (numpy or jax.numpy) from the start.
3318
"""
3419

20+
k_neighbors = 10
21+
kernel = 'wendland_c4'
22+
radius_scale = 1.5
23+
3524
xp = self._xp # numpy or jax.numpy
3625

3726
# ------------------------------------------------------------------
@@ -53,9 +42,9 @@ def _pix_sub_weights_from_query_points(self, query_points) -> PixSubWeights:
5342
weights_jax, indices_jax, _ = get_interpolation_weights(
5443
points=points,
5544
query_points=query_points,
56-
k_neighbors=int(self.k_neighbors),
57-
kernel=self.kernel,
58-
radius_scale=float(self.radius_scale),
45+
k_neighbors=int(k_neighbors),
46+
kernel=kernel,
47+
radius_scale=float(radius_scale),
5948
)
6049

6150
# ------------------------------------------------------------------

autoarray/inversion/pixelization/mesh/knn.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -259,7 +259,12 @@ def body_fun(i, out_acc):
259259

260260
class KNNInterpolator(Delaunay):
261261

262-
def __init__(self):
262+
def __init__(self, k_neighbors=10, kernel='wendland_c4',
263+
radius_scale=1.5):
264+
265+
self.k_neighbors = k_neighbors
266+
self.kernel = kernel
267+
self.radius_scale = radius_scale
263268

264269
super().__init__()
265270

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
from autoarray.structures.mesh.delaunay_2d import Mesh2DDelaunay
2+
3+
class Mesh2DDelaunayKNN(Mesh2DDelaunay):
4+
5+
def mesh_grid_from(
6+
self,
7+
source_plane_data_grid=None,
8+
source_plane_mesh_grid=None,
9+
preloads=None,
10+
xp=np,
11+
):
12+
"""
13+
Return the Delaunay ``source_plane_mesh_grid`` as a ``Mesh2DDelaunay`` object, which provides additional
14+
functionality for performing operations that exploit the geometry of a Delaunay mesh.
15+
16+
Parameters
17+
----------
18+
source_plane_data_grid
19+
A 2D grid of (y,x) coordinates associated with the unmasked 2D data after it has been transformed to the
20+
``source`` reference frame.
21+
source_plane_mesh_grid
22+
The centres of every Delaunay pixel in the ``source`` frame, which are initially derived by computing a sparse
23+
set of (y,x) coordinates computed from the unmasked data in the image-plane and applying a transformation
24+
to this.
25+
settings
26+
Settings controlling the pixelization for example if a border is used to relocate its exterior coordinates.
27+
"""
28+
29+
return Mesh2DDelaunayKNN(
30+
values=source_plane_mesh_grid,
31+
source_plane_data_grid_over_sampled=source_plane_data_grid,
32+
preloads=preloads,
33+
_xp=xp,
34+
)

0 commit comments

Comments
 (0)