Skip to content

Commit b68064e

Browse files
committed
Added method which reshapes a mask when loaded from a fits file.
1 parent 107f882 commit b68064e

2 files changed

Lines changed: 31 additions & 2 deletions

File tree

autoarray/mask/mask.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -373,7 +373,7 @@ def elliptical_annular(
373373
)
374374

375375
@classmethod
376-
def from_fits(cls, file_path, pixel_scales, hdu=0, sub_size=1, origin=(0.0, 0.0)):
376+
def from_fits(cls, file_path, pixel_scales, hdu=0, sub_size=1, origin=(0.0, 0.0), resized_mask_shape=None):
377377
"""
378378
Loads the image from a .fits file.
379379
@@ -391,13 +391,19 @@ def from_fits(cls, file_path, pixel_scales, hdu=0, sub_size=1, origin=(0.0, 0.0)
391391
if type(pixel_scales) is float or int:
392392
pixel_scales = (float(pixel_scales), float(pixel_scales))
393393

394-
return cls(
394+
mask = cls(
395395
array_util.numpy_array_2d_from_fits(file_path=file_path, hdu=hdu),
396396
pixel_scales=pixel_scales,
397397
sub_size=sub_size,
398398
origin=origin,
399399
)
400400

401+
if resized_mask_shape is not None:
402+
403+
mask = mask.mapping.resized_mask_from_new_shape(new_shape=resized_mask_shape)
404+
405+
return mask
406+
401407
def output_to_fits(self, file_path, overwrite=False):
402408

403409
array_util.numpy_array_2d_to_fits(

test_autoarray/unit/mask/test_mask.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -427,6 +427,8 @@ def test__mask_elliptical_annular_inverted__compare_to_array_util(self):
427427

428428

429429
class TestFromAndToFits:
430+
431+
430432
def test__load_and_output_mask_to_fits(self):
431433

432434
mask = msk.Mask.from_fits(
@@ -459,6 +461,27 @@ def test__load_and_output_mask_to_fits(self):
459461
assert mask.pixel_scales == (1.0, 1.0)
460462
assert mask.origin == (2.0, 2.0)
461463

464+
def test__load_from_fits_with_resized_mask_shape(self):
465+
466+
mask = msk.Mask.from_fits(
467+
file_path=test_data_dir + "3x3_ones.fits",
468+
hdu=0,
469+
sub_size=1,
470+
pixel_scales=(1.0, 1.0),
471+
resized_mask_shape=(1, 1)
472+
)
473+
474+
assert mask.shape_2d == (1, 1)
475+
476+
mask = msk.Mask.from_fits(
477+
file_path=test_data_dir + "3x3_ones.fits",
478+
hdu=0,
479+
sub_size=1,
480+
pixel_scales=(1.0, 1.0),
481+
resized_mask_shape=(5, 5)
482+
)
483+
484+
assert mask.shape_2d == (5, 5)
462485

463486
class TestSubQuantities:
464487
def test__sub_shape_is_shape_times_sub_size(self):

0 commit comments

Comments
 (0)