Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions optika/distortion/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,16 @@

from ._distortion import (
AbstractDistortionModel,
AbstractLinearDistortionModel,
SimpleDistortionModel,
AbstractInterpolatedDistortionModel,
PolynomialDistortionModel,
)

__all__ = [
"AbstractDistortionModel",
"AbstractLinearDistortionModel",
"SimpleDistortionModel",
"AbstractInterpolatedDistortionModel",
"PolynomialDistortionModel",
]
173 changes: 173 additions & 0 deletions optika/distortion/_distortion.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,21 @@
import abc
import dataclasses
import functools
import numpy as np
import matplotlib.axes
import matplotlib.cm
import matplotlib.colors
import matplotlib.figure
import matplotlib.pyplot as plt
import astropy.units as u
import astropy.visualization
import named_arrays as na
import optika

__all__ = [
"AbstractDistortionModel",
"AbstractLinearDistortionModel",
"SimpleDistortionModel",
"AbstractInterpolatedDistortionModel",
"PolynomialDistortionModel",
]
Expand Down Expand Up @@ -61,6 +65,175 @@ def undistort(
"""


@dataclasses.dataclass(eq=False, repr=False)
class AbstractLinearDistortionModel(
AbstractDistortionModel,
):
r"""
A distortion model which is an affine transformation of the scene
coordinates,

.. math::

\text{distort}(\vec{c}) = \mathbf{M} \, (\vec{c} - \vec{c}_0) + \vec{b},

where :math:`\mathbf{M}` is :attr:`matrix`, :math:`\vec{c}_0` is
:attr:`center`, and :math:`\vec{b}` is :attr:`intercept`.
Since the transformation is linear, :meth:`undistort` is its *exact*
inverse (unlike a polynomial fit).
"""

@property
@abc.abstractmethod
def matrix(self) -> na.AbstractSpectralPositionalMatrixArray:
"""The linear part of the affine transformation."""

@property
@abc.abstractmethod
def center(self) -> na.AbstractSpectralPositionalVectorArray:
"""The reference point subtracted from the coordinates before
applying :attr:`matrix`."""

@property
@abc.abstractmethod
def intercept(self) -> na.AbstractSpectralPositionalVectorArray:
"""The constant offset added after applying :attr:`matrix`."""

def distort(
self,
coordinates: na.AbstractSpectralPositionalVectorArray,
) -> na.SpectralPositionalVectorArray:
return self.matrix @ (coordinates - self.center) + self.intercept

def undistort(
self,
coordinates: na.AbstractSpectralPositionalVectorArray,
) -> na.SpectralPositionalVectorArray:
return self.matrix.inverse @ (coordinates - self.intercept) + self.center


@dataclasses.dataclass(eq=False, repr=False)
class SimpleDistortionModel(
AbstractLinearDistortionModel,
):
r"""
A simple analytic distortion model consisting of a rotation of the field,
an isotropic plate scale, and a linear spectral dispersion along the
rotated :math:`x` axis.

This captures the distortion of an idealized spectrograph: the field
center at the :attr:`reference` wavelength maps to the :attr:`reference`
position on the sensor, and other wavelengths are displaced along the
dispersion direction.

Examples
--------

Distort a grid of scene coordinates and plot the result on the sensor,
colored by wavelength.

.. jupyter-execute::

import matplotlib.pyplot as plt
import astropy.units as u
import named_arrays as na
import optika

model = optika.distortion.SimpleDistortionModel(
plate_scale=1 * u.arcsec / u.pix,
dispersion=2 * u.nm / u.pix,
angle=15 * u.deg,
reference=na.SpectralPositionalVectorArray(
wavelength=550 * u.nm,
position=na.Cartesian2dVectorArray(0, 0) * u.pix,
),
)

scene = na.SpectralPositionalVectorArray(
wavelength=na.linspace(500, 600, axis="wavelength", num=3) * u.nm,
position=na.Cartesian2dVectorLinearSpace(
start=-10 * u.arcsec,
stop=+10 * u.arcsec,
axis=na.Cartesian2dVectorArray("field_x", "field_y"),
num=5,
),
)

sensor = model.distort(scene)

fig, ax = plt.subplots(constrained_layout=True)
ax.set_aspect("equal")
for wavelength in scene.wavelength.ndarray:
na.plt.scatter(
sensor.position.x,
sensor.position.y,
where=scene.wavelength == wavelength,
label=f"{wavelength}",
ax=ax,
)
ax.set_xlabel(f"detector $x$ ({na.unit(sensor.position.x):latex_inline})")
ax.set_ylabel(f"detector $y$ ({na.unit(sensor.position.y):latex_inline})")
ax.legend();
"""

plate_scale: u.Quantity | na.AbstractScalar = dataclasses.MISSING
"""The spatial plate scale, in units such as :math:`\\text{arcsec} / \\text{pix}`."""

dispersion: u.Quantity | na.AbstractScalar = dataclasses.MISSING
"""The magnitude of the spectral dispersion, in units such as :math:`\\text{nm} / \\text{pix}`."""

angle: u.Quantity | na.AbstractScalar = dataclasses.MISSING
"""The angle of the dispersion direction with respect to the scene."""

reference: na.AbstractSpectralPositionalVectorArray = dataclasses.MISSING
"""The reference wavelength and the sensor position that the field center
maps to at that wavelength."""

@functools.cached_property
def matrix(self) -> na.SpectralPositionalMatrixArray:
cos = np.cos(self.angle)
sin = np.sin(self.angle)
plate_scale = self.plate_scale
dispersion = self.dispersion
unit_wavelength = na.unit(self.reference.wavelength)
return na.SpectralPositionalMatrixArray(
wavelength=na.SpectralPositionalVectorArray(
wavelength=1,
position=na.Cartesian2dVectorArray(
x=0 * unit_wavelength / u.arcsec,
y=0 * unit_wavelength / u.arcsec,
),
),
position=na.Cartesian2dMatrixArray(
x=na.SpectralPositionalVectorArray(
wavelength=1 / dispersion,
position=na.Cartesian2dVectorArray(
x=cos / plate_scale,
y=-sin / plate_scale,
),
),
y=na.SpectralPositionalVectorArray(
wavelength=0 / dispersion,
position=na.Cartesian2dVectorArray(
x=sin / plate_scale,
y=cos / plate_scale,
),
),
),
)

@property
def center(self) -> na.SpectralPositionalVectorArray:
return na.SpectralPositionalVectorArray(
wavelength=self.reference.wavelength,
position=na.Cartesian2dVectorArray(0, 0) * u.arcsec,
)

@property
def intercept(self) -> na.AbstractSpectralPositionalVectorArray:
return self.reference


@dataclasses.dataclass(eq=False, repr=False)
class AbstractInterpolatedDistortionModel(
AbstractDistortionModel,
Expand Down
34 changes: 34 additions & 0 deletions optika/distortion/_distortion_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,40 @@ def test_roundtrip(self, a: optika.distortion.AbstractDistortionModel):
assert np.all(error < 1e-9 * u.deg)


class AbstractTestAbstractLinearDistortionModel(
AbstractTestAbstractDistortionModel,
):
def test_matrix(self, a: optika.distortion.AbstractLinearDistortionModel):
assert isinstance(a.matrix, na.AbstractSpectralPositionalMatrixArray)

def test_center(self, a: optika.distortion.AbstractLinearDistortionModel):
assert isinstance(a.center, na.AbstractSpectralPositionalVectorArray)

def test_intercept(self, a: optika.distortion.AbstractLinearDistortionModel):
assert isinstance(a.intercept, na.AbstractSpectralPositionalVectorArray)


@pytest.mark.parametrize(
argnames="a",
argvalues=[
optika.distortion.SimpleDistortionModel(
plate_scale=1 * u.arcsec / u.pix,
dispersion=0.1 * u.nm / u.pix,
angle=angle,
reference=na.SpectralPositionalVectorArray(
wavelength=550 * u.nm,
position=na.Cartesian2dVectorArray(0, 0) * u.pix,
),
)
for angle in [0 * u.deg, 15 * u.deg]
],
)
class TestSimpleDistortionModel(
AbstractTestAbstractLinearDistortionModel,
):
pass


class AbstractTestAbstractInterpolatedDistortionModel(
AbstractTestAbstractDistortionModel,
):
Expand Down
Loading