From 90bee741fb432a17ad17c1091428fd3f4a61c27d Mon Sep 17 00:00:00 2001 From: xrhd Date: Sat, 27 Jun 2026 11:50:48 -0300 Subject: [PATCH] docs: document JAX backend for DragonNet Follow-up to #918 addressing the two documentation gaps flagged by @jeongyoonlee: - docs/examples.rst: add 'examples/dragonnet_jax_vs_tf' to the toctree so Sphinx renders the JAX vs TF benchmark notebook (previously omitted). - docs/installation.rst: document the 'jax' extra alongside 'tf'/'torch' (PyPI, uv, install-from-source) and mention 'inference.jax' in the intro; add '--runjax' to the Running Tests section. - docs/examples/dragonnet_jax_vs_tf.ipynb: rename the '## Installation' markdown heading to '## Setup' to avoid an autosectionlabel collision with installation.rst that the new toctree entry surfaces. Verified with a local Sphinx build (nbsphinx_execute=never): no new warnings introduced (870 vs 872 baseline) and the notebook renders correctly. --- docs/examples.rst | 1 + docs/examples/dragonnet_jax_vs_tf.ipynb | 2 +- docs/installation.rst | 24 ++++++++++++++++++++++-- 3 files changed, 24 insertions(+), 3 deletions(-) diff --git a/docs/examples.rst b/docs/examples.rst index aa302c37..135f0204 100644 --- a/docs/examples.rst +++ b/docs/examples.rst @@ -15,6 +15,7 @@ Follow the below links for an approximate ordering of example tutorials from int examples/feature_interpretations_example examples/validation_with_tmle examples/dragonnet_example + examples/dragonnet_jax_vs_tf examples/iv_nlsym_synthetic_data examples/sensitivity_example_with_synthetic_data examples/counterfactual_unit_selection diff --git a/docs/examples/dragonnet_jax_vs_tf.ipynb b/docs/examples/dragonnet_jax_vs_tf.ipynb index 51e5d396..11f33b92 100644 --- a/docs/examples/dragonnet_jax_vs_tf.ipynb +++ b/docs/examples/dragonnet_jax_vs_tf.ipynb @@ -18,7 +18,7 @@ "id": "2613c834", "metadata": {}, "source": [ - "## Installation\n", + "## Setup\n", "\n", "**TF backend:**\n", "```\n", diff --git a/docs/installation.rst b/docs/installation.rst index 8a52737b..8718b7f3 100644 --- a/docs/installation.rst +++ b/docs/installation.rst @@ -4,7 +4,7 @@ Installation Installation with ``conda`` or ``pip`` is recommended. Developers can follow the **Install from source** instructions below. If building from source, consider doing so within a conda environment and then exporting the environment for reproducibility. -To use models under the ``inference.tf`` or ``inference.torch`` module (e.g. ``DragonNet`` or ``CEVAE``), additional dependency of ``tensorflow`` or ``torch`` is required. For detailed instructions, see below. +To use models under the ``inference.tf``, ``inference.torch`` or ``inference.jax`` module (e.g. ``DragonNet`` or ``CEVAE``), additional dependency of ``tensorflow``, ``torch`` or ``jax`` is required. For detailed instructions, see below. System Requirements ------------------- @@ -67,6 +67,13 @@ Install ``causalml`` with ``torch`` for ``CEVAE`` from ``PyPI`` pip install causalml[torch] +Install ``causalml`` with ``jax`` for ``DragonNet`` from ``PyPI`` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. code-block:: bash + + pip install causalml[jax] + Install using `uv `_ --------------------- @@ -89,6 +96,13 @@ Install ``causalml`` with ``torch`` for ``CEVAE`` using `uv `_ +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. code-block:: bash + + uv add "causalml[jax]" @@ -126,6 +140,12 @@ with ``torch`` for ``CEVAE``: pip install -e ".[torch]" +with ``jax`` for ``DragonNet``: + +.. code-block:: bash + + pip install -e ".[jax]" + ======= Windows @@ -149,7 +169,7 @@ Run all tests with: pytest -vs tests/ --cov causalml/ -Add ``--runtf`` and/or ``--runtorch`` to run optional tensorflow/torch tests which will be skipped by default. +Add ``--runtf``, ``--runtorch`` and/or ``--runjax`` to run optional tensorflow/torch/jax tests which will be skipped by default. You can also run tests via make: