@@ -144,6 +144,97 @@ def data_slim_to_pixelization_unique_from(
144144 return data_to_pix_unique , data_weights , pix_lengths
145145
146146
147+ def rectangular_mappings_weights_via_interpolation_from (
148+ shape_native : Tuple [int , int ],
149+ source_plane_data_grid : np .ndarray ,
150+ source_plane_mesh_grid : np .ndarray ,
151+ ):
152+ """
153+ Compute bilinear interpolation weights and corresponding rectangular mesh indices for an irregular grid.
154+
155+ Given a flattened regular rectangular mesh grid and an irregular grid of data points, this function
156+ determines for each irregular point:
157+ - the indices of the 4 nearest rectangular mesh pixels (top-left, top-right, bottom-left, bottom-right), and
158+ - the bilinear interpolation weights with respect to those pixels.
159+
160+ The function supports JAX and is compatible with JIT compilation.
161+
162+ Parameters
163+ ----------
164+ shape_native
165+ The shape (Ny, Nx) of the original rectangular mesh grid before flattening.
166+ source_plane_data_grid
167+ The irregular grid of (y, x) points to interpolate.
168+ source_plane_mesh_grid
169+ The flattened regular rectangular mesh grid of (y, x) coordinates.
170+
171+ Returns
172+ -------
173+ mappings : np.ndarray of shape (N, 4)
174+ Indices of the four nearest rectangular mesh pixels in the flattened mesh grid.
175+ Order is: top-left, top-right, bottom-left, bottom-right.
176+ weights : np.ndarray of shape (N, 4)
177+ Bilinear interpolation weights corresponding to the four nearest mesh pixels.
178+
179+ Notes
180+ -----
181+ - Assumes the mesh grid is uniformly spaced.
182+ - The weights sum to 1 for each irregular point.
183+ - Uses bilinear interpolation in the (y, x) coordinate system.
184+ """
185+ source_plane_mesh_grid = source_plane_mesh_grid .reshape (* shape_native , 2 )
186+
187+ # Assume mesh is shaped (Ny, Nx, 2)
188+ Ny , Nx = source_plane_mesh_grid .shape [:2 ]
189+
190+ # Get mesh spacings and lower corner
191+ y_coords = source_plane_mesh_grid [:, 0 , 0 ] # shape (Ny,)
192+ x_coords = source_plane_mesh_grid [0 , :, 1 ] # shape (Nx,)
193+
194+ dy = y_coords [1 ] - y_coords [0 ]
195+ dx = x_coords [1 ] - x_coords [0 ]
196+
197+ y_min = y_coords [0 ]
198+ x_min = x_coords [0 ]
199+
200+ # shape (N_irregular, 2)
201+ irregular = source_plane_data_grid
202+
203+ # Compute normalized mesh coordinates (floating indices)
204+ fy = (irregular [:, 0 ] - y_min ) / dy
205+ fx = (irregular [:, 1 ] - x_min ) / dx
206+
207+ # Integer indices of top-left corners
208+ ix = np .floor (fx ).astype (np .int32 )
209+ iy = np .floor (fy ).astype (np .int32 )
210+
211+ # Clip to stay within bounds
212+ ix = np .clip (ix , 0 , Nx - 2 )
213+ iy = np .clip (iy , 0 , Ny - 2 )
214+
215+ # Local coordinates inside the cell (0 <= tx, ty <= 1)
216+ tx = fx - ix
217+ ty = fy - iy
218+
219+ # Bilinear weights
220+ w00 = (1 - tx ) * (1 - ty )
221+ w10 = tx * (1 - ty )
222+ w01 = (1 - tx ) * ty
223+ w11 = tx * ty
224+
225+ weights = np .stack ([w00 , w10 , w01 , w11 ], axis = 1 ) # shape (N_irregular, 4)
226+
227+ # Compute indices of 4 surrounding pixels in the flattened mesh
228+ i00 = iy * Nx + ix
229+ i10 = iy * Nx + (ix + 1 )
230+ i01 = (iy + 1 ) * Nx + ix
231+ i11 = (iy + 1 ) * Nx + (ix + 1 )
232+
233+ mappings = np .stack ([i00 , i10 , i01 , i11 ], axis = 1 ) # shape (N_irregular, 4)
234+
235+ return mappings , weights
236+
237+
147238@numba_util .jit ()
148239def pix_indexes_for_sub_slim_index_delaunay_from (
149240 source_plane_data_grid ,
0 commit comments