diff --git a/docs/examples.rst b/docs/examples.rst index aa302c37..135f0204 100644 --- a/docs/examples.rst +++ b/docs/examples.rst @@ -15,6 +15,7 @@ Follow the below links for an approximate ordering of example tutorials from int examples/feature_interpretations_example examples/validation_with_tmle examples/dragonnet_example + examples/dragonnet_jax_vs_tf examples/iv_nlsym_synthetic_data examples/sensitivity_example_with_synthetic_data examples/counterfactual_unit_selection diff --git a/docs/examples/dragonnet_jax_vs_tf.ipynb b/docs/examples/dragonnet_jax_vs_tf.ipynb index 51e5d396..11f33b92 100644 --- a/docs/examples/dragonnet_jax_vs_tf.ipynb +++ b/docs/examples/dragonnet_jax_vs_tf.ipynb @@ -18,7 +18,7 @@ "id": "2613c834", "metadata": {}, "source": [ - "## Installation\n", + "## Setup\n", "\n", "**TF backend:**\n", "```\n", diff --git a/docs/installation.rst b/docs/installation.rst index 8a52737b..8718b7f3 100644 --- a/docs/installation.rst +++ b/docs/installation.rst @@ -4,7 +4,7 @@ Installation Installation with ``conda`` or ``pip`` is recommended. Developers can follow the **Install from source** instructions below. If building from source, consider doing so within a conda environment and then exporting the environment for reproducibility. -To use models under the ``inference.tf`` or ``inference.torch`` module (e.g. ``DragonNet`` or ``CEVAE``), additional dependency of ``tensorflow`` or ``torch`` is required. For detailed instructions, see below. +To use models under the ``inference.tf``, ``inference.torch`` or ``inference.jax`` module (e.g. ``DragonNet`` or ``CEVAE``), additional dependency of ``tensorflow``, ``torch`` or ``jax`` is required. For detailed instructions, see below. System Requirements ------------------- @@ -67,6 +67,13 @@ Install ``causalml`` with ``torch`` for ``CEVAE`` from ``PyPI`` pip install causalml[torch] +Install ``causalml`` with ``jax`` for ``DragonNet`` from ``PyPI`` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. code-block:: bash + + pip install causalml[jax] + Install using `uv `_ --------------------- @@ -89,6 +96,13 @@ Install ``causalml`` with ``torch`` for ``CEVAE`` using `uv `_ +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. code-block:: bash + + uv add "causalml[jax]" @@ -126,6 +140,12 @@ with ``torch`` for ``CEVAE``: pip install -e ".[torch]" +with ``jax`` for ``DragonNet``: + +.. code-block:: bash + + pip install -e ".[jax]" + ======= Windows @@ -149,7 +169,7 @@ Run all tests with: pytest -vs tests/ --cov causalml/ -Add ``--runtf`` and/or ``--runtorch`` to run optional tensorflow/torch tests which will be skipped by default. +Add ``--runtf``, ``--runtorch`` and/or ``--runjax`` to run optional tensorflow/torch/jax tests which will be skipped by default. You can also run tests via make: