Skip to content

Commit c226c2d

Browse files
Jammy2211Jammy2211
authored andcommitted
rectangular grid uses interpolation
1 parent d8dee45 commit c226c2d

2 files changed

Lines changed: 101 additions & 11 deletions

File tree

autoarray/inversion/pixelization/mappers/mapper_util.py

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,97 @@ def data_slim_to_pixelization_unique_from(
144144
return data_to_pix_unique, data_weights, pix_lengths
145145

146146

147+
def rectangular_mappings_weights_via_interpolation_from(
148+
shape_native: Tuple[int, int],
149+
source_plane_data_grid: np.ndarray,
150+
source_plane_mesh_grid: np.ndarray,
151+
):
152+
"""
153+
Compute bilinear interpolation weights and corresponding rectangular mesh indices for an irregular grid.
154+
155+
Given a flattened regular rectangular mesh grid and an irregular grid of data points, this function
156+
determines for each irregular point:
157+
- the indices of the 4 nearest rectangular mesh pixels (top-left, top-right, bottom-left, bottom-right), and
158+
- the bilinear interpolation weights with respect to those pixels.
159+
160+
The function supports JAX and is compatible with JIT compilation.
161+
162+
Parameters
163+
----------
164+
shape_native
165+
The shape (Ny, Nx) of the original rectangular mesh grid before flattening.
166+
source_plane_data_grid
167+
The irregular grid of (y, x) points to interpolate.
168+
source_plane_mesh_grid
169+
The flattened regular rectangular mesh grid of (y, x) coordinates.
170+
171+
Returns
172+
-------
173+
mappings : np.ndarray of shape (N, 4)
174+
Indices of the four nearest rectangular mesh pixels in the flattened mesh grid.
175+
Order is: top-left, top-right, bottom-left, bottom-right.
176+
weights : np.ndarray of shape (N, 4)
177+
Bilinear interpolation weights corresponding to the four nearest mesh pixels.
178+
179+
Notes
180+
-----
181+
- Assumes the mesh grid is uniformly spaced.
182+
- The weights sum to 1 for each irregular point.
183+
- Uses bilinear interpolation in the (y, x) coordinate system.
184+
"""
185+
source_plane_mesh_grid = source_plane_mesh_grid.reshape(*shape_native, 2)
186+
187+
# Assume mesh is shaped (Ny, Nx, 2)
188+
Ny, Nx = source_plane_mesh_grid.shape[:2]
189+
190+
# Get mesh spacings and lower corner
191+
y_coords = source_plane_mesh_grid[:, 0, 0] # shape (Ny,)
192+
x_coords = source_plane_mesh_grid[0, :, 1] # shape (Nx,)
193+
194+
dy = y_coords[1] - y_coords[0]
195+
dx = x_coords[1] - x_coords[0]
196+
197+
y_min = y_coords[0]
198+
x_min = x_coords[0]
199+
200+
# shape (N_irregular, 2)
201+
irregular = source_plane_data_grid
202+
203+
# Compute normalized mesh coordinates (floating indices)
204+
fy = (irregular[:, 0] - y_min) / dy
205+
fx = (irregular[:, 1] - x_min) / dx
206+
207+
# Integer indices of top-left corners
208+
ix = np.floor(fx).astype(np.int32)
209+
iy = np.floor(fy).astype(np.int32)
210+
211+
# Clip to stay within bounds
212+
ix = np.clip(ix, 0, Nx - 2)
213+
iy = np.clip(iy, 0, Ny - 2)
214+
215+
# Local coordinates inside the cell (0 <= tx, ty <= 1)
216+
tx = fx - ix
217+
ty = fy - iy
218+
219+
# Bilinear weights
220+
w00 = (1 - tx) * (1 - ty)
221+
w10 = tx * (1 - ty)
222+
w01 = (1 - tx) * ty
223+
w11 = tx * ty
224+
225+
weights = np.stack([w00, w10, w01, w11], axis=1) # shape (N_irregular, 4)
226+
227+
# Compute indices of 4 surrounding pixels in the flattened mesh
228+
i00 = iy * Nx + ix
229+
i10 = iy * Nx + (ix + 1)
230+
i01 = (iy + 1) * Nx + ix
231+
i11 = (iy + 1) * Nx + (ix + 1)
232+
233+
mappings = np.stack([i00, i10, i01, i11], axis=1) # shape (N_irregular, 4)
234+
235+
return mappings, weights
236+
237+
147238
@numba_util.jit()
148239
def pix_indexes_for_sub_slim_index_delaunay_from(
149240
source_plane_data_grid,

autoarray/inversion/pixelization/mappers/rectangular.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from autoarray.numba_util import profile_func
1010
from autoarray.geometry import geometry_util
1111

12+
from autoarray.inversion.pixelization.mappers import mapper_util
1213

1314
class MapperRectangular(AbstractMapper):
1415
"""
@@ -99,19 +100,17 @@ def pix_sub_weights(self) -> PixSubWeights:
99100
dimension of the array `pix_indexes_for_sub_slim_index` 1 and all entries in `pix_weights_for_sub_slim_index`
100101
are equal to 1.0.
101102
"""
102-
mappings = geometry_util.grid_pixel_indexes_2d_slim_from(
103-
grid_scaled_2d_slim=np.array(self.source_plane_data_grid.over_sampled),
104-
shape_native=self.source_plane_mesh_grid.shape_native,
105-
pixel_scales=self.source_plane_mesh_grid.pixel_scales,
106-
origin=self.source_plane_mesh_grid.origin,
107-
).astype("int")
108103

109-
mappings = mappings.reshape((len(mappings), 1))
104+
mappings, weights = (
105+
mapper_util.rectangular_mappings_weights_via_interpolation_from(
106+
shape_native=self.shape_native,
107+
source_plane_mesh_grid=self.source_plane_mesh_grid.array,
108+
source_plane_data_grid=self.source_plane_data_grid.over_sampled,
109+
)
110+
)
110111

111112
return PixSubWeights(
112113
mappings=mappings,
113-
sizes=np.ones(len(mappings), dtype="int"),
114-
weights=np.ones(
115-
(len(self.source_plane_data_grid.over_sampled), 1), dtype="int"
116-
),
114+
sizes=4 * np.ones(len(mappings), dtype="int"),
115+
weights=weights,
117116
)

0 commit comments

Comments
 (0)