diff --git a/.tools/envs/testenv-linux.yml b/.tools/envs/testenv-linux.yml index f3d704a7f..12e5258cf 100644 --- a/.tools/envs/testenv-linux.yml +++ b/.tools/envs/testenv-linux.yml @@ -21,6 +21,7 @@ dependencies: - plotly>=6.2 # run, tests - matplotlib # tests - bokeh # tests + - altair # tests - pybaum>=0.1.2 # run, tests - scipy>=1.2.1 # run, tests - sqlalchemy # run, tests diff --git a/.tools/envs/testenv-nevergrad.yml b/.tools/envs/testenv-nevergrad.yml index f008144dc..89a4a3e66 100644 --- a/.tools/envs/testenv-nevergrad.yml +++ b/.tools/envs/testenv-nevergrad.yml @@ -19,6 +19,7 @@ dependencies: - plotly>=6.2 # run, tests - matplotlib # tests - bokeh # tests + - altair # tests - pybaum>=0.1.2 # run, tests - scipy>=1.2.1 # run, tests - sqlalchemy # run, tests diff --git a/.tools/envs/testenv-numpy.yml b/.tools/envs/testenv-numpy.yml index 34355b9f4..3c8493065 100644 --- a/.tools/envs/testenv-numpy.yml +++ b/.tools/envs/testenv-numpy.yml @@ -19,6 +19,7 @@ dependencies: - plotly>=6.2 # run, tests - matplotlib # tests - bokeh # tests + - altair # tests - pybaum>=0.1.2 # run, tests - scipy>=1.2.1 # run, tests - sqlalchemy # run, tests diff --git a/.tools/envs/testenv-others.yml b/.tools/envs/testenv-others.yml index 3db398224..601ff9768 100644 --- a/.tools/envs/testenv-others.yml +++ b/.tools/envs/testenv-others.yml @@ -19,6 +19,7 @@ dependencies: - plotly>=6.2 # run, tests - matplotlib # tests - bokeh # tests + - altair # tests - pybaum>=0.1.2 # run, tests - scipy>=1.2.1 # run, tests - sqlalchemy # run, tests diff --git a/.tools/envs/testenv-pandas.yml b/.tools/envs/testenv-pandas.yml index 4fbe7e512..bfbd06d5a 100644 --- a/.tools/envs/testenv-pandas.yml +++ b/.tools/envs/testenv-pandas.yml @@ -19,6 +19,7 @@ dependencies: - plotly>=6.2 # run, tests - matplotlib # tests - bokeh # tests + - altair # tests - pybaum>=0.1.2 # run, tests - scipy>=1.2.1 # run, tests - sqlalchemy # run, tests diff --git a/.tools/envs/testenv-plotly.yml b/.tools/envs/testenv-plotly.yml index fec9ed3e7..4f80422ee 100644 --- a/.tools/envs/testenv-plotly.yml +++ b/.tools/envs/testenv-plotly.yml @@ -19,6 +19,7 @@ dependencies: - pandas # run, tests - matplotlib # tests - bokeh # tests + - altair # tests - pybaum>=0.1.2 # run, tests - scipy>=1.2.1 # run, tests - sqlalchemy # run, tests diff --git a/docs/rtd_environment.yml b/docs/rtd_environment.yml index 5ff33f75c..ef417f966 100644 --- a/docs/rtd_environment.yml +++ b/docs/rtd_environment.yml @@ -20,6 +20,8 @@ dependencies: - pybaum - matplotlib - bokeh + - altair + - anywidget - seaborn - numpy - pandas diff --git a/docs/source/how_to/how_to_change_plotting_backend.ipynb b/docs/source/how_to/how_to_change_plotting_backend.ipynb index c8e07237f..2d3b32f5a 100644 --- a/docs/source/how_to/how_to_change_plotting_backend.ipynb +++ b/docs/source/how_to/how_to_change_plotting_backend.ipynb @@ -57,6 +57,20 @@ "```{warning}\n", "Bokeh applies themes globally. Passing the `template` parameter to a plotting function updates the theme for all existing and future Bokeh plots. If you do not pass `template`, a default template is applied, which will also change the global theme.\n", "```\n", + "\n", + ":::\n", + "\n", + ":::{tab-item} Altair\n", + "To select the Altair backend, set `backend=\"altair\"`.\n", + "\n", + "The returned figure object is an [`altair.Chart`](https://altair-viz.github.io/user_guide/generated/toplevel/altair.Chart.html).\n", + "\n", + "```{note}\n", + "In case of grid plots (such as `convergence_plot` or `slice_plot`), the returned object is either an [`altair.Chart`](https://altair-viz.github.io/user_guide/generated/toplevel/altair.Chart.html) if there is only one subplot, an [`altair.HConcatChart`](https://altair-viz.github.io/user_guide/generated/toplevel/altair.HConcatChart.html) if there is only one row, or an [`altair.VConcatChart`](https://altair-viz.github.io/user_guide/generated/toplevel/altair.VConcatChart.html) otherwise.\n", + "```\n", + "\n", + ":::\n", + "\n", "::::" ] }, @@ -170,11 +184,51 @@ "p = om.criterion_plot(results, backend=\"bokeh\")\n", "show(p)" ] + }, + { + "cell_type": "markdown", + "id": "12", + "metadata": {}, + "source": [ + "## Altair" + ] + }, + { + "cell_type": "markdown", + "id": "13", + "metadata": {}, + "source": [ + ":::{note}\n", + "\n", + "It is mostly not required to set the renderer manually, as Altair automatically\n", + "selects the appropriate renderer based on your environment. In this example,\n", + "we explicitly set the renderer to ensure correct display within the documentation.\n", + "\n", + "Refer to the [Altair documentation](https://altair-viz.github.io/user_guide/display_frontends.html) for more details.\n", + "\n", + ":::" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "14", + "metadata": {}, + "outputs": [], + "source": [ + "import altair as alt\n", + "\n", + "# Setting the renderer is mostly not required. See note above.\n", + "alt.renderers.enable(\"jupyter\")\n", + "\n", + "chart = om.criterion_plot(results, backend=\"altair\")\n", + "chart.show()" + ] } ], "metadata": { "kernelspec": { - "display_name": "Python 3", + "display_name": "optimagic", "language": "python", "name": "python3" }, @@ -188,7 +242,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.18" + "version": "3.10.17" } }, "nbformat": 4, diff --git a/environment.yml b/environment.yml index da4a58840..797305556 100644 --- a/environment.yml +++ b/environment.yml @@ -23,6 +23,7 @@ dependencies: - plotly>=6.2 # run, tests - matplotlib # tests - bokeh # tests + - altair # tests - pybaum>=0.1.2 # run, tests - scipy>=1.2.1 # run, tests - sqlalchemy # run, tests diff --git a/src/optimagic/config.py b/src/optimagic/config.py index def8302f7..2f944bd9c 100644 --- a/src/optimagic/config.py +++ b/src/optimagic/config.py @@ -63,6 +63,7 @@ def _is_installed(module_name: str) -> bool: IS_MATPLOTLIB_INSTALLED = _is_installed("matplotlib") IS_BOKEH_INSTALLED = _is_installed("bokeh") +IS_ALTAIR_INSTALLED = _is_installed("altair") # ====================================================================================== # Check if pandas version is newer or equal to version 2.1.0 diff --git a/src/optimagic/visualization/backends.py b/src/optimagic/visualization/backends.py index c8195cc55..8299847a2 100644 --- a/src/optimagic/visualization/backends.py +++ b/src/optimagic/visualization/backends.py @@ -4,11 +4,16 @@ import numpy as np import plotly.graph_objects as go -from optimagic.config import IS_BOKEH_INSTALLED, IS_MATPLOTLIB_INSTALLED +from optimagic.config import ( + IS_ALTAIR_INSTALLED, + IS_BOKEH_INSTALLED, + IS_MATPLOTLIB_INSTALLED, +) from optimagic.exceptions import InvalidPlottingBackendError, NotInstalledError from optimagic.visualization.plotting_utilities import LineData, MarkerData if TYPE_CHECKING: + import altair as alt import bokeh import matplotlib.pyplot as plt @@ -592,9 +597,260 @@ def _grid_line_plot_bokeh( return grid +def _line_plot_altair( + lines: list[LineData], + *, + title: str | None, + xlabel: str | None, + xrange: tuple[float, float] | None, + ylabel: str | None, + yrange: tuple[float, float] | None, + template: str | None, + height: int | None, + width: int | None, + legend_properties: dict[str, Any] | None, + margin_properties: dict[str, Any] | None, + horizontal_line: float | None, + marker: MarkerData | None, + subplot: "alt.Chart | None" = None, +) -> "alt.Chart": + """Create a line plot using Altair. + + Args: + ...: All other argument descriptions can be found in the docstring of the + `line_plot` function. + subplot: An Altair Chart object to which the lines should be plotted. + If provided, the plot is drawn on the given Chart. If not provided, + a new Chart is created. + + Returns: + An Altair Chart object. + + """ + import altair as alt + import pandas as pd + + alt.data_transformers.disable_max_rows() + + if template is None: + template = "default" + alt.theme.enable("default") + + dfs = [] + for line in lines: + df = pd.DataFrame( + {"x": line.x, "y": line.y, "name": line.name, "color": line.color} + ) + dfs.append(df) + source = pd.concat(dfs) + + figure_properties: dict[str, str | int] = {} + if title is not None: + figure_properties["title"] = title + if width is not None: + figure_properties["width"] = width + if height is not None: + figure_properties["height"] = height + + chart = ( + alt.Chart(source) + .mark_line() + .encode( + x=alt.X( + "x", + title=xlabel.split("{linebreak}") if xlabel else None, + scale=alt.Scale(domain=list(xrange)) if xrange else alt.Undefined, + ), + y=alt.Y( + "y", + title=ylabel.split("{linebreak}") if ylabel else None, + scale=alt.Scale(domain=list(yrange)) if yrange else alt.Undefined, + ), + color=alt.Color("color:N", scale=None), + detail="name:N", + ) + .properties(**figure_properties) + ) + + if any(line.show_in_legend for line in lines): + legend = ( + alt.Chart(source) + .mark_line() + .encode( + color=alt.Color( + "name:N", + title=None, + legend=alt.Legend(**(legend_properties or {})), + scale=alt.Scale( + domain=[line.name for line in lines if line.show_in_legend], + range=[ + line.color or "" for line in lines if line.show_in_legend + ], + ), + ) + ) + ) + chart = chart + legend + + if horizontal_line is not None: + hline = ( + alt.Chart(pd.DataFrame({"y": [horizontal_line]})).mark_rule().encode(y="y") + ) + chart = chart + hline + + if marker is not None: + marker_chart = ( + alt.Chart(pd.DataFrame({"x": [marker.x], "y": [marker.y]})) + .mark_point(size=100, shape="circle", color=marker.color, filled=True) + .encode(x="x", y="y") + ) + chart = chart + marker_chart + + return chart.interactive() + + +def _grid_line_plot_altair( + lines_list: list[list[LineData]], + *, + n_rows: int, + n_cols: int, + titles: list[str] | None, + xlabels: list[str] | None, + xrange: tuple[float, float] | None, + share_x: bool, + ylabels: list[str] | None, + yrange: tuple[float, float] | None, + share_y: bool, + template: str | None, + height: int | None, + width: int | None, + legend_properties: dict[str, Any] | None, + margin_properties: dict[str, Any] | None, + plot_title: str | None, + marker_list: list[MarkerData] | None, + make_subplot_kwargs: dict[str, Any] | None = None, +) -> "alt.Chart | alt.HConcatChart | alt.VConcatChart": + """Create a grid of line plots using Altair. + + Args: + ...: All other argument descriptions can be found in the docstring of the + `grid_line_plot` function. + + Returns: + An Altair Chart if the grid contains only one subplot, an Altair HConcatChart + if 'n_rows' is 1, otherwise an Altair VConcatChart. + + """ + import altair as alt + + subplot_height = height // n_rows if height else None + subplot_width = width // n_cols if width else None + + charts = [] + for row_idx in range(n_rows): + chart_row = [] + for col_idx in range(n_cols): + i = row_idx * n_cols + col_idx + if i >= len(lines_list): + break + + chart = _line_plot_altair( + lines_list[i], + title=titles[i] if titles else None, + xlabel=xlabels[i] if xlabels else None, + xrange=xrange, + ylabel=ylabels[i] if ylabels else None, + yrange=yrange, + template=template, + height=subplot_height, + width=subplot_width, + legend_properties=legend_properties, + margin_properties=None, + horizontal_line=None, + marker=marker_list[i] if marker_list else None, + subplot=None, + ) + + chart_row.append(chart) + charts.append(chart_row) + + row_selections = [ + alt.selection_interval( + bind="scales", encodings=["y"], name=f"share_y_row{row_idx}" + ) + for row_idx in range(n_rows) + ] + col_selections = [ + alt.selection_interval( + bind="scales", encodings=["x"], name=f"share_x_col{col_idx}" + ) + for col_idx in range(n_cols) + ] + + for row_idx, row in enumerate(charts): + for col_idx in range(len(row)): + chart = row[col_idx] + + params = [] + if share_y: + # Share y-axis for all subplots in the same row + params.append(row_selections[row_idx]) + else: + # Use independent y-axes for each subplot + params.append( + alt.selection_interval( + bind="scales", + encodings=["y"], + name=f"ind_y_row{row_idx}_col{col_idx}", + ) + ) + if share_x: + # Share x-axis for all subplots in the same column + params.append(col_selections[col_idx]) + else: + # Use independent x-axes for each subplot + params.append( + alt.selection_interval( + bind="scales", + encodings=["x"], + name=f"ind_x_row{row_idx}_col{col_idx}", + ) + ) + chart = chart.add_params(*params) + + if share_y and col_idx > 0: + # Hide y-axis ticklabels for all subplots except the leftmost column + chart = chart.encode(y=alt.Y(axis=alt.Axis(labels=False))) + if share_x and row_idx < n_rows - 1: + # Hide x-axis ticklabels for all subplots except the bottom row + chart = chart.encode(x=alt.X(axis=alt.Axis(labels=False))) + + charts[row_idx][col_idx] = chart + + row_charts = [] + for row in charts: + row_chart: alt.Chart | alt.HConcatChart + if len(row) == 1: + row_chart = row[0] + else: + row_chart = alt.hconcat(*row) + row_charts.append(row_chart) + + grid_chart: alt.Chart | alt.HConcatChart | alt.VConcatChart + if len(row_charts) == 1: + grid_chart = row_charts[0] + else: + grid_chart = alt.vconcat(*row_charts) + + if plot_title is not None: + grid_chart = grid_chart.properties(title=plot_title) + + return grid_chart + + def line_plot( lines: list[LineData], - backend: Literal["plotly", "matplotlib", "bokeh"] = "plotly", + backend: Literal["plotly", "matplotlib", "bokeh", "altair"] = "plotly", *, title: str | None = None, xlabel: str | None = None, @@ -657,7 +913,7 @@ def line_plot( def grid_line_plot( lines_list: list[list[LineData]], - backend: Literal["plotly", "matplotlib", "bokeh"] = "plotly", + backend: Literal["plotly", "matplotlib", "bokeh", "altair"] = "plotly", *, n_rows: int, n_cols: int, @@ -750,18 +1006,25 @@ def grid_line_plot( _line_plot_bokeh, _grid_line_plot_bokeh, ), + "altair": ( + IS_ALTAIR_INSTALLED, + _line_plot_altair, + _grid_line_plot_altair, + ), } @overload def _get_plot_function( - backend: Literal["plotly", "matplotlib", "bokeh"], grid_plot: Literal[False] + backend: Literal["plotly", "matplotlib", "bokeh", "altair"], + grid_plot: Literal[False], ) -> LinePlotFunction: ... @overload def _get_plot_function( - backend: Literal["plotly", "matplotlib", "bokeh"], grid_plot: Literal[True] + backend: Literal["plotly", "matplotlib", "bokeh", "altair"], + grid_plot: Literal[True], ) -> GridLinePlotFunction: ... diff --git a/src/optimagic/visualization/convergence_plot.py b/src/optimagic/visualization/convergence_plot.py index 3287a9b4d..487749d11 100644 --- a/src/optimagic/visualization/convergence_plot.py +++ b/src/optimagic/visualization/convergence_plot.py @@ -19,6 +19,7 @@ "place": "right", "label_text_font_size": "8pt", }, + "altair": {"orient": "right"}, } BACKEND_TO_CONVERGENCE_PLOT_MARGIN_PROPERTIES: dict[str, dict[str, int]] = { @@ -75,7 +76,7 @@ def convergence_plot( x_precision: float = 1e-4, y_precision: float = 1e-4, combine_plots_in_grid: bool = True, - backend: Literal["plotly", "matplotlib", "bokeh"] = "plotly", + backend: Literal["plotly", "matplotlib", "bokeh", "altair"] = "plotly", template: str | None = None, palette: list[str] | str = DEFAULT_PALETTE, ) -> Any: diff --git a/src/optimagic/visualization/history_plots.py b/src/optimagic/visualization/history_plots.py index d14d8412b..14d684508 100644 --- a/src/optimagic/visualization/history_plots.py +++ b/src/optimagic/visualization/history_plots.py @@ -30,6 +30,9 @@ "bokeh": { "location": "top_right", }, + "altair": { + "orient": "top-right", + }, } @@ -40,7 +43,7 @@ def criterion_plot( results: ResultOrPath | list[ResultOrPath] | dict[str, ResultOrPath], names: list[str] | str | None = None, max_evaluations: int | None = None, - backend: Literal["plotly", "matplotlib", "bokeh"] = "plotly", + backend: Literal["plotly", "matplotlib", "bokeh", "altair"] = "plotly", template: str | None = None, palette: list[str] | str = DEFAULT_PALETTE, stack_multistart: bool = False, @@ -158,7 +161,7 @@ def params_plot( result: ResultOrPath, selector: Callable[[PyTree], PyTree] | None = None, max_evaluations: int | None = None, - backend: Literal["plotly", "matplotlib", "bokeh"] = "plotly", + backend: Literal["plotly", "matplotlib", "bokeh", "altair"] = "plotly", template: str | None = None, palette: list[str] | str = DEFAULT_PALETTE, show_exploration: bool = False, diff --git a/src/optimagic/visualization/profile_plot.py b/src/optimagic/visualization/profile_plot.py index cf9482aba..003382e5b 100644 --- a/src/optimagic/visualization/profile_plot.py +++ b/src/optimagic/visualization/profile_plot.py @@ -25,6 +25,7 @@ "label_text_font_size": "8pt", "title": "algorithm", }, + "altair": {"orient": "right", "title": "algorithm"}, } BACKEND_TO_PROFILE_PLOT_MARGIN_PROPERTIES: dict[str, dict[str, Any]] = { @@ -44,7 +45,7 @@ def profile_plot( stopping_criterion: Literal["x", "y", "x_and_y", "x_or_y"] = "y", x_precision: float = 1e-4, y_precision: float = 1e-4, - backend: Literal["plotly", "matplotlib", "bokeh"] = "plotly", + backend: Literal["plotly", "matplotlib", "bokeh", "altair"] = "plotly", template: str | None = None, palette: list[str] | str = DEFAULT_PALETTE, ) -> Any: diff --git a/src/optimagic/visualization/slice_plot.py b/src/optimagic/visualization/slice_plot.py index c630b7a1d..df98f363a 100644 --- a/src/optimagic/visualization/slice_plot.py +++ b/src/optimagic/visualization/slice_plot.py @@ -44,7 +44,7 @@ def slice_plot( share_y: bool = True, expand_yrange: float = 0.02, share_x: bool = False, - backend: Literal["plotly", "matplotlib", "bokeh"] = "plotly", + backend: Literal["plotly", "matplotlib", "bokeh", "altair"] = "plotly", template: str | None = None, color: str | None = DEFAULT_PALETTE[0], title: str | None = None, diff --git a/tests/optimagic/visualization/test_convergence_plot.py b/tests/optimagic/visualization/test_convergence_plot.py index 3a66131af..931786e30 100644 --- a/tests/optimagic/visualization/test_convergence_plot.py +++ b/tests/optimagic/visualization/test_convergence_plot.py @@ -54,6 +54,7 @@ def test_convergence_plot_default_options(benchmark_results): {"y_precision": 1e-5}, {"backend": "matplotlib"}, {"backend": "bokeh"}, + {"backend": "altair"}, ] diff --git a/tests/optimagic/visualization/test_profile_plot.py b/tests/optimagic/visualization/test_profile_plot.py index 5589c2e30..a5f482455 100644 --- a/tests/optimagic/visualization/test_profile_plot.py +++ b/tests/optimagic/visualization/test_profile_plot.py @@ -182,6 +182,7 @@ def test_extract_profile_plot_lines(): {"stopping_criterion": "x_or_y"}, {"backend": "matplotlib"}, {"backend": "bokeh"}, + {"backend": "altair"}, ]