diff --git a/optika/distortion/__init__.py b/optika/distortion/__init__.py index 01f41c1..9f6d39a 100644 --- a/optika/distortion/__init__.py +++ b/optika/distortion/__init__.py @@ -2,12 +2,16 @@ from ._distortion import ( AbstractDistortionModel, + AbstractLinearDistortionModel, + SimpleDistortionModel, AbstractInterpolatedDistortionModel, PolynomialDistortionModel, ) __all__ = [ "AbstractDistortionModel", + "AbstractLinearDistortionModel", + "SimpleDistortionModel", "AbstractInterpolatedDistortionModel", "PolynomialDistortionModel", ] diff --git a/optika/distortion/_distortion.py b/optika/distortion/_distortion.py index b67534e..ddd5197 100644 --- a/optika/distortion/_distortion.py +++ b/optika/distortion/_distortion.py @@ -1,17 +1,21 @@ import abc import dataclasses import functools +import numpy as np import matplotlib.axes import matplotlib.cm import matplotlib.colors import matplotlib.figure import matplotlib.pyplot as plt +import astropy.units as u import astropy.visualization import named_arrays as na import optika __all__ = [ "AbstractDistortionModel", + "AbstractLinearDistortionModel", + "SimpleDistortionModel", "AbstractInterpolatedDistortionModel", "PolynomialDistortionModel", ] @@ -61,6 +65,175 @@ def undistort( """ +@dataclasses.dataclass(eq=False, repr=False) +class AbstractLinearDistortionModel( + AbstractDistortionModel, +): + r""" + A distortion model which is an affine transformation of the scene + coordinates, + + .. math:: + + \text{distort}(\vec{c}) = \mathbf{M} \, (\vec{c} - \vec{c}_0) + \vec{b}, + + where :math:`\mathbf{M}` is :attr:`matrix`, :math:`\vec{c}_0` is + :attr:`center`, and :math:`\vec{b}` is :attr:`intercept`. + Since the transformation is linear, :meth:`undistort` is its *exact* + inverse (unlike a polynomial fit). + """ + + @property + @abc.abstractmethod + def matrix(self) -> na.AbstractSpectralPositionalMatrixArray: + """The linear part of the affine transformation.""" + + @property + @abc.abstractmethod + def center(self) -> na.AbstractSpectralPositionalVectorArray: + """The reference point subtracted from the coordinates before + applying :attr:`matrix`.""" + + @property + @abc.abstractmethod + def intercept(self) -> na.AbstractSpectralPositionalVectorArray: + """The constant offset added after applying :attr:`matrix`.""" + + def distort( + self, + coordinates: na.AbstractSpectralPositionalVectorArray, + ) -> na.SpectralPositionalVectorArray: + return self.matrix @ (coordinates - self.center) + self.intercept + + def undistort( + self, + coordinates: na.AbstractSpectralPositionalVectorArray, + ) -> na.SpectralPositionalVectorArray: + return self.matrix.inverse @ (coordinates - self.intercept) + self.center + + +@dataclasses.dataclass(eq=False, repr=False) +class SimpleDistortionModel( + AbstractLinearDistortionModel, +): + r""" + A simple analytic distortion model consisting of a rotation of the field, + an isotropic plate scale, and a linear spectral dispersion along the + rotated :math:`x` axis. + + This captures the distortion of an idealized spectrograph: the field + center at the :attr:`reference` wavelength maps to the :attr:`reference` + position on the sensor, and other wavelengths are displaced along the + dispersion direction. + + Examples + -------- + + Distort a grid of scene coordinates and plot the result on the sensor, + colored by wavelength. + + .. jupyter-execute:: + + import matplotlib.pyplot as plt + import astropy.units as u + import named_arrays as na + import optika + + model = optika.distortion.SimpleDistortionModel( + plate_scale=1 * u.arcsec / u.pix, + dispersion=2 * u.nm / u.pix, + angle=15 * u.deg, + reference=na.SpectralPositionalVectorArray( + wavelength=550 * u.nm, + position=na.Cartesian2dVectorArray(0, 0) * u.pix, + ), + ) + + scene = na.SpectralPositionalVectorArray( + wavelength=na.linspace(500, 600, axis="wavelength", num=3) * u.nm, + position=na.Cartesian2dVectorLinearSpace( + start=-10 * u.arcsec, + stop=+10 * u.arcsec, + axis=na.Cartesian2dVectorArray("field_x", "field_y"), + num=5, + ), + ) + + sensor = model.distort(scene) + + fig, ax = plt.subplots(constrained_layout=True) + ax.set_aspect("equal") + for wavelength in scene.wavelength.ndarray: + na.plt.scatter( + sensor.position.x, + sensor.position.y, + where=scene.wavelength == wavelength, + label=f"{wavelength}", + ax=ax, + ) + ax.set_xlabel(f"detector $x$ ({na.unit(sensor.position.x):latex_inline})") + ax.set_ylabel(f"detector $y$ ({na.unit(sensor.position.y):latex_inline})") + ax.legend(); + """ + + plate_scale: u.Quantity | na.AbstractScalar = dataclasses.MISSING + """The spatial plate scale, in units such as :math:`\\text{arcsec} / \\text{pix}`.""" + + dispersion: u.Quantity | na.AbstractScalar = dataclasses.MISSING + """The magnitude of the spectral dispersion, in units such as :math:`\\text{nm} / \\text{pix}`.""" + + angle: u.Quantity | na.AbstractScalar = dataclasses.MISSING + """The angle of the dispersion direction with respect to the scene.""" + + reference: na.AbstractSpectralPositionalVectorArray = dataclasses.MISSING + """The reference wavelength and the sensor position that the field center + maps to at that wavelength.""" + + @functools.cached_property + def matrix(self) -> na.SpectralPositionalMatrixArray: + cos = np.cos(self.angle) + sin = np.sin(self.angle) + plate_scale = self.plate_scale + dispersion = self.dispersion + unit_wavelength = na.unit(self.reference.wavelength) + return na.SpectralPositionalMatrixArray( + wavelength=na.SpectralPositionalVectorArray( + wavelength=1, + position=na.Cartesian2dVectorArray( + x=0 * unit_wavelength / u.arcsec, + y=0 * unit_wavelength / u.arcsec, + ), + ), + position=na.Cartesian2dMatrixArray( + x=na.SpectralPositionalVectorArray( + wavelength=1 / dispersion, + position=na.Cartesian2dVectorArray( + x=cos / plate_scale, + y=-sin / plate_scale, + ), + ), + y=na.SpectralPositionalVectorArray( + wavelength=0 / dispersion, + position=na.Cartesian2dVectorArray( + x=sin / plate_scale, + y=cos / plate_scale, + ), + ), + ), + ) + + @property + def center(self) -> na.SpectralPositionalVectorArray: + return na.SpectralPositionalVectorArray( + wavelength=self.reference.wavelength, + position=na.Cartesian2dVectorArray(0, 0) * u.arcsec, + ) + + @property + def intercept(self) -> na.AbstractSpectralPositionalVectorArray: + return self.reference + + @dataclasses.dataclass(eq=False, repr=False) class AbstractInterpolatedDistortionModel( AbstractDistortionModel, diff --git a/optika/distortion/_distortion_test.py b/optika/distortion/_distortion_test.py index f2a77d0..721b12f 100644 --- a/optika/distortion/_distortion_test.py +++ b/optika/distortion/_distortion_test.py @@ -44,6 +44,40 @@ def test_roundtrip(self, a: optika.distortion.AbstractDistortionModel): assert np.all(error < 1e-9 * u.deg) +class AbstractTestAbstractLinearDistortionModel( + AbstractTestAbstractDistortionModel, +): + def test_matrix(self, a: optika.distortion.AbstractLinearDistortionModel): + assert isinstance(a.matrix, na.AbstractSpectralPositionalMatrixArray) + + def test_center(self, a: optika.distortion.AbstractLinearDistortionModel): + assert isinstance(a.center, na.AbstractSpectralPositionalVectorArray) + + def test_intercept(self, a: optika.distortion.AbstractLinearDistortionModel): + assert isinstance(a.intercept, na.AbstractSpectralPositionalVectorArray) + + +@pytest.mark.parametrize( + argnames="a", + argvalues=[ + optika.distortion.SimpleDistortionModel( + plate_scale=1 * u.arcsec / u.pix, + dispersion=0.1 * u.nm / u.pix, + angle=angle, + reference=na.SpectralPositionalVectorArray( + wavelength=550 * u.nm, + position=na.Cartesian2dVectorArray(0, 0) * u.pix, + ), + ) + for angle in [0 * u.deg, 15 * u.deg] + ], +) +class TestSimpleDistortionModel( + AbstractTestAbstractLinearDistortionModel, +): + pass + + class AbstractTestAbstractInterpolatedDistortionModel( AbstractTestAbstractDistortionModel, ):