diff --git a/ctis/instruments/_instruments.py b/ctis/instruments/_instruments.py index 6d02cf0..748a579 100644 --- a/ctis/instruments/_instruments.py +++ b/ctis/instruments/_instruments.py @@ -39,7 +39,7 @@ class AbstractInstrument( @abc.abstractmethod def image( self, - scene: na.AbstractScalar, + scene: na.AbstractScalar | na.AbstractFunctionArray, integrate: bool = True, noise: bool = True, ) -> na.FunctionArray[na.SpectralPositionalVectorArray, na.AbstractScalar]: @@ -66,7 +66,7 @@ def image( @abc.abstractmethod def backproject( self, - image: na.AbstractScalar, + image: na.AbstractScalar | na.AbstractFunctionArray, integrate: bool = True, ) -> na.FunctionArray[na.SpectralPositionalVectorArray, na.AbstractScalar]: """ @@ -233,11 +233,18 @@ def _energy_per_photon(self) -> u.Quantity | na.AbstractScalar: def image( self, - scene: na.AbstractScalar, + scene: na.AbstractScalar | na.AbstractFunctionArray, integrate: bool = True, noise: bool = True, ) -> na.FunctionArray[na.SpectralPositionalVectorArray, na.AbstractScalar]: + if isinstance(scene, na.AbstractFunctionArray): + if not np.all(scene.inputs == self.coordinates_scene): + raise ValueError( + "`scene.inputs` and `self.coordinates_scene` are not equal." + ) + scene = scene.outputs + values_input = scene * self._volume_scene values_input = values_input / self._energy_per_photon @@ -279,10 +286,17 @@ def image( def backproject( self, - image: na.AbstractScalar, + image: na.AbstractScalar | na.AbstractFunctionArray, integrate: bool = True, ) -> na.FunctionArray[na.SpectralPositionalVectorArray, na.AbstractScalar]: + if isinstance(image, na.AbstractFunctionArray): + if not np.all(image.inputs.position == self.coordinates_sensor.position): + raise ValueError( + "`image.inputs` and `self.coordinates_sensor` are not equal." + ) + image = image.outputs + coordinates = self.coordinates_scene axis_wavelength = self.axis_wavelength @@ -514,7 +528,7 @@ def weights_transpose(self): def image( self, - scene: na.AbstractScalar, + scene: na.AbstractScalar | na.AbstractFunctionArray, integrate: bool = True, noise: bool = True, ) -> na.FunctionArray[na.SpectralPositionalVectorArray, na.AbstractScalar]: @@ -529,7 +543,7 @@ def image( def backproject( self, - image: na.AbstractScalar, + image: na.AbstractScalar | na.AbstractFunctionArray, integrate: bool = True, ) -> na.FunctionArray[na.SpectralPositionalVectorArray, na.AbstractScalar]: diff --git a/ctis/instruments/_instruments_test.py b/ctis/instruments/_instruments_test.py index 0f64c44..0e06ff1 100644 --- a/ctis/instruments/_instruments_test.py +++ b/ctis/instruments/_instruments_test.py @@ -70,7 +70,7 @@ class AbstractTestAbstractInstrument( def test_image( self, a: ctis.instruments.AbstractInstrument, - scene: na.AbstractScalar, + scene: na.AbstractScalar | na.AbstractFunctionArray, ): result = a.image(scene) assert np.all(result.inputs.position == coordinates_sensor.position) @@ -79,20 +79,23 @@ def test_image( @pytest.mark.parametrize( argnames="image", argvalues=[ - instrument_ideal.image(gaussians.outputs, noise=False).outputs, + instrument_ideal.image(gaussians, noise=False), ], ) def test_backproject( self, a: ctis.instruments.AbstractInstrument, - image: na.AbstractScalar, + image: na.AbstractScalar | na.AbstractFunctionArray, ): result = a.backproject(image) assert np.all(result.inputs == coordinates_scene) assert result.outputs.sum() > 0 - image_check = a.image(result.outputs, noise=False).outputs + if isinstance(image, na.AbstractFunctionArray): + image = image.outputs + + image_check = a.image(result, noise=False).outputs assert np.allclose(image.sum(), image_check.sum()) diff --git a/docs/tutorials/ideal-instrument.ipynb b/docs/tutorials/ideal-instrument.ipynb index 8b5b955..708ae7d 100644 --- a/docs/tutorials/ideal-instrument.ipynb +++ b/docs/tutorials/ideal-instrument.ipynb @@ -501,9 +501,7 @@ "tags": [] }, "outputs": [], - "source": [ - "image = instrument.image(scene.outputs, integrate=False)" - ] + "source": "image = instrument.image(scene, integrate=False)" }, { "cell_type": "raw", @@ -602,9 +600,7 @@ "tags": [] }, "outputs": [], - "source": [ - "image_sum = instrument.image(scene.outputs)" - ] + "source": "image_sum = instrument.image(scene)" }, { "cell_type": "raw", @@ -633,9 +629,7 @@ "tags": [] }, "outputs": [], - "source": [ - "backprojected = instrument.backproject(image_sum.outputs)" - ] + "source": "backprojected = instrument.backproject(image_sum)" }, { "cell_type": "raw", diff --git a/pyproject.toml b/pyproject.toml index 3e5a422..cfbf262 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,3 +47,10 @@ Documentation = "https://ctis.readthedocs.io/en/latest" packages = ["ctis"] [tool.setuptools_scm] + +[tool.coverage.report] +exclude_also = [ + "return NotImplemented", + "raise ValueError", + "raise NotImplementedError", +]