Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/examples.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ Follow the below links for an approximate ordering of example tutorials from int
examples/feature_interpretations_example
examples/validation_with_tmle
examples/dragonnet_example
examples/dragonnet_jax_vs_tf
examples/iv_nlsym_synthetic_data
examples/sensitivity_example_with_synthetic_data
examples/counterfactual_unit_selection
Expand Down
2 changes: 1 addition & 1 deletion docs/examples/dragonnet_jax_vs_tf.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
"id": "2613c834",
"metadata": {},
"source": [
"## Installation\n",
"## Setup\n",
"\n",
"**TF backend:**\n",
"```\n",
Expand Down
24 changes: 22 additions & 2 deletions docs/installation.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ Installation

Installation with ``conda`` or ``pip`` is recommended. Developers can follow the **Install from source** instructions below. If building from source, consider doing so within a conda environment and then exporting the environment for reproducibility.

To use models under the ``inference.tf`` or ``inference.torch`` module (e.g. ``DragonNet`` or ``CEVAE``), additional dependency of ``tensorflow`` or ``torch`` is required. For detailed instructions, see below.
To use models under the ``inference.tf``, ``inference.torch`` or ``inference.jax`` module (e.g. ``DragonNet`` or ``CEVAE``), additional dependency of ``tensorflow``, ``torch`` or ``jax`` is required. For detailed instructions, see below.

System Requirements
-------------------
Expand Down Expand Up @@ -67,6 +67,13 @@ Install ``causalml`` with ``torch`` for ``CEVAE`` from ``PyPI``

pip install causalml[torch]

Install ``causalml`` with ``jax`` for ``DragonNet`` from ``PyPI``
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. code-block:: bash

pip install causalml[jax]


Install using `uv <https://github.com/astral-sh/uv/blob/main/README.md>`_
---------------------
Expand All @@ -89,6 +96,13 @@ Install ``causalml`` with ``torch`` for ``CEVAE`` using `uv <https://github.com/
.. code-block:: bash

uv add "causalml[torch]"

Install ``causalml`` with ``jax`` for ``DragonNet`` using `uv <https://github.com/astral-sh/uv/blob/main/README.md>`_
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. code-block:: bash

uv add "causalml[jax]"



Expand Down Expand Up @@ -126,6 +140,12 @@ with ``torch`` for ``CEVAE``:

pip install -e ".[torch]"

with ``jax`` for ``DragonNet``:

.. code-block:: bash

pip install -e ".[jax]"

=======

Windows
Expand All @@ -149,7 +169,7 @@ Run all tests with:

pytest -vs tests/ --cov causalml/

Add ``--runtf`` and/or ``--runtorch`` to run optional tensorflow/torch tests which will be skipped by default.
Add ``--runtf``, ``--runtorch`` and/or ``--runjax`` to run optional tensorflow/torch/jax tests which will be skipped by default.

You can also run tests via make:

Expand Down