diff --git a/optika/__init__.py b/optika/__init__.py index bbb8ea4..699bdcf 100644 --- a/optika/__init__.py +++ b/optika/__init__.py @@ -18,6 +18,7 @@ from . import surfaces from . import sensors from . import distortion +from . import vignetting from . import systems __all__ = [ @@ -39,5 +40,6 @@ "surfaces", "sensors", "distortion", + "vignetting", "systems", ] diff --git a/optika/vignetting/__init__.py b/optika/vignetting/__init__.py new file mode 100644 index 0000000..14f6989 --- /dev/null +++ b/optika/vignetting/__init__.py @@ -0,0 +1,13 @@ +"""Model the vignetting of a scene observed by an optical system.""" + +from ._vignetting import ( + AbstractVignettingModel, + AbstractInterpolatedVignettingModel, + PolynomialVignettingModel, +) + +__all__ = [ + "AbstractVignettingModel", + "AbstractInterpolatedVignettingModel", + "PolynomialVignettingModel", +] diff --git a/optika/vignetting/_vignetting.py b/optika/vignetting/_vignetting.py new file mode 100644 index 0000000..074f9c2 --- /dev/null +++ b/optika/vignetting/_vignetting.py @@ -0,0 +1,347 @@ +import abc +import dataclasses +import functools +import matplotlib.axes +import matplotlib.cm +import matplotlib.colors +import matplotlib.figure +import matplotlib.pyplot as plt +import astropy.visualization +import named_arrays as na +import optika + +__all__ = [ + "AbstractVignettingModel", + "AbstractInterpolatedVignettingModel", + "PolynomialVignettingModel", +] + + +@dataclasses.dataclass(eq=False, repr=False) +class AbstractVignettingModel( + optika.mixins.Printable, +): + """ + An interface describing an arbitrary vignetting model, which maps scene + coordinates to the fraction of light transmitted by the optical system. + """ + + @abc.abstractmethod + def __call__( + self, + coordinates: na.AbstractSpectralPositionalVectorArray, + ) -> na.AbstractScalar: + """ + Compute the fraction of light transmitted for the given scene + coordinates. + + Parameters + ---------- + coordinates + The wavelength and position of each point in the scene. + """ + + def inverse( + self, + coordinates: na.AbstractSpectralPositionalVectorArray, + ) -> na.AbstractScalar: + r""" + Compute the inverse of the transmission, :math:`1 / T`, the factor + which corrects for the vignetting at the given scene coordinates. + + Parameters + ---------- + coordinates + The wavelength and position of each point in the scene. + """ + return 1 / self(coordinates) + + +@dataclasses.dataclass(eq=False, repr=False) +class AbstractInterpolatedVignettingModel( + AbstractVignettingModel, +): + """ + A vignetting model defined by interpolating between known scene coordinates + and their measured transmission. + + This class has two main members, :attr:`coordinates_scene` and + :attr:`transmission`, the calibration points between which subclasses + interpolate. + """ + + @property + @abc.abstractmethod + def coordinates_scene(self) -> na.AbstractSpectralPositionalVectorArray: + """ + The wavelength and position of each calibration point in the scene. + """ + + @property + @abc.abstractmethod + def transmission(self) -> na.AbstractScalar: + """ + The fraction of light transmitted at each calibration point. + """ + + @property + @abc.abstractmethod + def axis_wavelength(self) -> str: + """The logical axis corresponding to changing wavelength.""" + + @property + @abc.abstractmethod + def axis_field(self) -> tuple[str, str]: + """The logical axes corresponding to changing position in the scene.""" + + +@dataclasses.dataclass(eq=False, repr=False) +class PolynomialVignettingModel( + AbstractInterpolatedVignettingModel, +): + """ + A vignetting model which fits a polynomial to the measured transmission at + known scene coordinates. + + Examples + -------- + + Build a vignetting model with a radial transmission falloff fit by a + deliberately underfit (linear) polynomial, then plot the transmission and + the fit residual. + + .. jupyter-execute:: + + import numpy as np + import astropy.units as u + import named_arrays as na + import optika + + scene = na.SpectralPositionalVectorArray( + wavelength=na.linspace(500, 600, axis="wavelength", num=3) * u.nm, + position=na.Cartesian2dVectorLinearSpace( + start=-1 * u.deg, + stop=+1 * u.deg, + axis=na.Cartesian2dVectorArray("field_x", "field_y"), + num=13, + ), + ) + transmission = 1 - 0.1 * (scene.position.length / u.deg) ** 2 + + model = optika.vignetting.PolynomialVignettingModel( + coordinates_scene=scene, + transmission=transmission, + axis_wavelength="wavelength", + axis_field=("field_x", "field_y"), + degree=1, + ) + + fig, ax = model.plot() + na.plt.set_aspect("equal", ax=ax); + + fig, ax = model.plot_residual() + na.plt.set_aspect("equal", ax=ax); + """ + + coordinates_scene: na.AbstractSpectralPositionalVectorArray = dataclasses.MISSING + """The wavelength and position of each calibration point in the scene.""" + + transmission: na.AbstractScalar = dataclasses.MISSING + """The fraction of light transmitted at each calibration point.""" + + axis_wavelength: str = dataclasses.MISSING + """The logical axis corresponding to changing wavelength.""" + + axis_field: tuple[str, str] = dataclasses.MISSING + """The logical axes corresponding to changing position in the scene.""" + + degree: int = 1 + """The degree of the polynomial used to model the vignetting.""" + + where: bool | na.AbstractScalar = True + """A boolean mask selecting which calibration points to use for fitting.""" + + @property + def _axis_scene(self) -> tuple[str, ...]: + """The logical axes over which the calibration points are distributed.""" + return (self.axis_wavelength, *self.axis_field) + + @functools.cached_property + def fit(self) -> na.PolynomialFitFunctionArray: + """The polynomial fit mapping scene coordinates to transmission.""" + scene = self.coordinates_scene + return na.PolynomialFitFunctionArray( + inputs=scene, + outputs=self.transmission, + center=scene.mean(self._axis_scene), + degree=self.degree, + where_polynomial=self.where, + ) + + def __call__( + self, + coordinates: na.AbstractSpectralPositionalVectorArray, + ) -> na.AbstractScalar: + return self.fit(coordinates).outputs + + def plot_residual( + self, + figsize: None | tuple[float, float] = None, + cmap: None | str | matplotlib.colors.Colormap = None, + vmin: None | na.ArrayLike = None, + vmax: None | na.ArrayLike = None, + **kwargs, + ) -> tuple[matplotlib.figure.Figure, na.ScalarArray]: + """ + Plot the residual of the :attr:`fit` as a function of field angle, with + a separate subplot for each wavelength. + + The residual is the absolute difference between the calibration + :attr:`transmission` and the transmission predicted by the polynomial + fit. + + Parameters + ---------- + figsize + The size of the returned figure in inches. + If :obj:`None`, the size is chosen automatically from the number + of wavelengths and the aspect ratio of the field of view. + cmap + The colormap used to map the residual to colors. + vmin + The residual value mapped to the lowest color. + If :obj:`None`, defaults to zero. + vmax + The residual value mapped to the highest color. + If :obj:`None`, defaults to the maximum residual. + kwargs + Additional keyword arguments passed to + :func:`named_arrays.plt.pcolormesh`. + """ + return self._plot( + abs(self.transmission - self.fit.predictions), + label="transmission residual", + figsize=figsize, + cmap=cmap, + vmin=vmin, + vmax=vmax, + **kwargs, + ) + + def plot( + self, + figsize: None | tuple[float, float] = None, + cmap: None | str | matplotlib.colors.Colormap = None, + vmin: None | na.ArrayLike = None, + vmax: None | na.ArrayLike = None, + **kwargs, + ) -> tuple[matplotlib.figure.Figure, na.ScalarArray]: + """ + Plot the calibration :attr:`transmission` as a function of field angle, + with a separate subplot for each wavelength. + + Parameters + ---------- + figsize + The size of the returned figure in inches. + If :obj:`None`, the size is chosen automatically from the number + of wavelengths and the aspect ratio of the field of view. + cmap + The colormap used to map the transmission to colors. + vmin + The transmission value mapped to the lowest color. + If :obj:`None`, defaults to zero. + vmax + The transmission value mapped to the highest color. + If :obj:`None`, defaults to the maximum transmission. + kwargs + Additional keyword arguments passed to + :func:`named_arrays.plt.pcolormesh`. + """ + return self._plot( + self.transmission, + label="transmission", + figsize=figsize, + cmap=cmap, + vmin=vmin, + vmax=vmax, + **kwargs, + ) + + def _plot( + self, + values: na.AbstractScalar, + label: str, + figsize: None | tuple[float, float] = None, + cmap: None | str | matplotlib.colors.Colormap = None, + vmin: None | na.ArrayLike = None, + vmax: None | na.ArrayLike = None, + **kwargs, + ) -> tuple[matplotlib.figure.Figure, na.ScalarArray]: + """ + Plot a scalar quantity as a function of field angle, with a separate + subplot for each wavelength. + """ + scene = self.coordinates_scene + position = scene.position + wavelength = na.as_named_array(scene.wavelength) + axis_wavelength = self.axis_wavelength + + if vmin is None: + vmin = 0 + if vmax is None: + vmax = values.max() + + ncols = na.shape(wavelength).get(axis_wavelength, 1) + + if figsize is None: + # shape each subplot to the field-of-view aspect ratio, and widen + # the figure to fit one subplot per wavelength + height_subplot = 3 + aspect = (position.x.ptp() / position.y.ptp()).ndarray.value + figsize = ( + ncols * height_subplot * aspect + 1.5, + height_subplot + 1, + ) + + with astropy.visualization.quantity_support(): + fig, ax = na.plt.subplots( + axis_cols=axis_wavelength, + ncols=ncols, + sharex=True, + sharey=True, + squeeze=False, + figsize=figsize, + constrained_layout=True, + ) + + colorizer = plt.Colorizer( + cmap=cmap, + norm=plt.Normalize( + vmin=na.as_named_array(vmin).ndarray, + vmax=na.as_named_array(vmax).ndarray, + ), + ) + + na.plt.pcolormesh( + position, + C=values, + ax=ax, + colorizer=colorizer, + **kwargs, + ) + + na.plt.set_xlabel(f"field $x$ ({na.unit(position.x):latex_inline})", ax=ax) + na.plt.set_ylabel( + f"field $y$ ({na.unit(position.y):latex_inline})", + ax=ax[{axis_wavelength: 0}], + ) + na.plt.set_title(wavelength.to_string_array(), ax=ax) + + plt.colorbar( + mappable=matplotlib.cm.ScalarMappable(colorizer=colorizer), + ax=ax.ndarray, + label=label, + ) + + return fig, ax diff --git a/optika/vignetting/_vignetting_test.py b/optika/vignetting/_vignetting_test.py new file mode 100644 index 0000000..d6618db --- /dev/null +++ b/optika/vignetting/_vignetting_test.py @@ -0,0 +1,127 @@ +import pytest +import numpy as np +import astropy.units as u +import matplotlib.figure +import matplotlib.pyplot as plt +import named_arrays as na +import optika +from .._tests import test_mixins + + +def _scene() -> na.SpectralPositionalVectorArray: + return na.SpectralPositionalVectorArray( + wavelength=na.linspace(500, 600, axis="wavelength", num=3) * u.nm, + position=na.Cartesian2dVectorLinearSpace( + start=-1 * u.deg, + stop=+1 * u.deg, + axis=na.Cartesian2dVectorArray("field_x", "field_y"), + num=5, + ), + ) + + +def _transmission() -> na.AbstractScalar: + return 1 - 0.1 * (_scene().position.length / u.deg) ** 2 + + +class AbstractTestAbstractVignettingModel( + test_mixins.AbstractTestPrintable, +): + def test__call__(self, a: optika.vignetting.AbstractVignettingModel): + scene = _scene() + result = a(scene) + assert isinstance(result, na.AbstractScalar) + for ax in ("field_x", "field_y"): + assert ax in na.shape(result) + + def test_inverse(self, a: optika.vignetting.AbstractVignettingModel): + scene = _scene() + result = a.inverse(scene) + assert isinstance(result, na.AbstractScalar) + assert np.all(result == 1 / a(scene)) + + +class AbstractTestAbstractInterpolatedVignettingModel( + AbstractTestAbstractVignettingModel, +): + def test_coordinates_scene( + self, + a: optika.vignetting.AbstractInterpolatedVignettingModel, + ): + assert isinstance(a.coordinates_scene, na.AbstractSpectralPositionalVectorArray) + + def test_transmission( + self, + a: optika.vignetting.AbstractInterpolatedVignettingModel, + ): + assert isinstance(a.transmission, na.AbstractScalar) + + def test_axis_wavelength( + self, + a: optika.vignetting.AbstractInterpolatedVignettingModel, + ): + assert isinstance(a.axis_wavelength, str) + + def test_axis_field( + self, + a: optika.vignetting.AbstractInterpolatedVignettingModel, + ): + assert isinstance(a.axis_field, tuple) + assert all(isinstance(ax, str) for ax in a.axis_field) + + +@pytest.mark.parametrize( + argnames="a", + argvalues=[ + optika.vignetting.PolynomialVignettingModel( + coordinates_scene=_scene(), + transmission=_transmission(), + axis_wavelength="wavelength", + axis_field=("field_x", "field_y"), + degree=degree, + ) + for degree in [1, 2] + ], +) +class TestPolynomialVignettingModel( + AbstractTestAbstractInterpolatedVignettingModel, +): + def test_fit(self, a: optika.vignetting.PolynomialVignettingModel): + assert isinstance(a.fit, na.PolynomialFitFunctionArray) + assert a.fit.degree == a.degree + + @pytest.mark.parametrize( + argnames="kwargs", + argvalues=[ + dict(), + dict(figsize=(8, 4), cmap="viridis", vmin=0, vmax=0.01), + ], + ) + def test_plot( + self, + a: optika.vignetting.PolynomialVignettingModel, + kwargs: dict, + ): + fig, ax = a.plot(**kwargs) + assert isinstance(fig, matplotlib.figure.Figure) + assert isinstance(ax, na.ScalarArray) + assert a.axis_wavelength in na.shape(ax) + plt.close(fig) + + @pytest.mark.parametrize( + argnames="kwargs", + argvalues=[ + dict(), + dict(figsize=(8, 4), cmap="viridis", vmin=0, vmax=0.01), + ], + ) + def test_plot_residual( + self, + a: optika.vignetting.PolynomialVignettingModel, + kwargs: dict, + ): + fig, ax = a.plot_residual(**kwargs) + assert isinstance(fig, matplotlib.figure.Figure) + assert isinstance(ax, na.ScalarArray) + assert a.axis_wavelength in na.shape(ax) + plt.close(fig)