This project has two primary goals, namely to provide a way to use XLA:
- Directly in C++ without needing to build all of XLA
- In a CMake-based project
XLA is a C++ project, but ordinarily the only way to use it from another project is to build all of XLA and to build your own project with Bazel.
- Since we are not building XLA, we need to get it from somewhere; in effect, we need an "XLA release". For now, we use the XLA plugin that ships with jax[cuda]. This plugin is a shared library.
- We also need an API to interface with that shared library. For this, we use the PJRT C API from OpenXLA.
This project works by loading the XLA shared library and wrapping the C API in an easier-to-use C++ API that does not depend on anything external.
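As a rough sketch of that mechanism (not this project's actual code): the plugin can be opened with dlopen and the C API obtained through the standard GetPjrtApi entry point declared in pjrt_c_api.h. Reading the plugin path from PJRT_PLUGIN_FULL_PATH_CONFIG at runtime is only for illustration; this project passes that path through CMake instead, as described below.

```cpp
// Minimal illustration: load the jax[cuda] PJRT plugin and fetch the C API.
#include <dlfcn.h>
#include <cstdio>
#include <cstdlib>
#include "pjrt_c_api.h"

int main() {
  // Illustration only: this project normally gets the path via CMake from
  // PJRT_PLUGIN_FULL_PATH_CONFIG rather than reading it at runtime.
  const char* path = std::getenv("PJRT_PLUGIN_FULL_PATH_CONFIG");
  if (path == nullptr) {
    std::fprintf(stderr, "PJRT_PLUGIN_FULL_PATH_CONFIG is not set\n");
    return 1;
  }

  void* handle = dlopen(path, RTLD_NOW | RTLD_LOCAL);
  if (handle == nullptr) {
    std::fprintf(stderr, "dlopen failed: %s\n", dlerror());
    return 1;
  }

  // Every PJRT plugin exports GetPjrtApi; it returns the table of C API
  // function pointers that a C++ wrapper builds on.
  using GetPjrtApiFn = const PJRT_Api* (*)();
  auto get_api = reinterpret_cast<GetPjrtApiFn>(dlsym(handle, "GetPjrtApi"));
  if (get_api == nullptr) {
    std::fprintf(stderr, "dlsym(GetPjrtApi) failed: %s\n", dlerror());
    return 1;
  }

  const PJRT_Api* api = get_api();
  std::printf("Loaded PJRT C API version %d.%d\n",
              api->pjrt_api_version.major_version,
              api->pjrt_api_version.minor_version);
  // From here a wrapper would call api->PJRT_Client_Create, compile StableHLO
  // with api->PJRT_Client_Compile, transfer buffers, and execute.
  return 0;
}
```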
After cloning this repository, a few steps are required before you are ready to build:
- Go grab pjrt_c_api.h from OpenXLA and place it at the root level of this project. We provide one, but it is not guaranteed to be kept up to date.
- Use get_compile_options.ipynb to reuse the JAX code that creates the compile options passed to PJRT. Ideally, run this Python code on the same machine on which you intend to use this C++ library. Copy the resulting compile option bytes into the compileOptionsData array in pjrt/client.cpp (see the sketch after these setup steps).
- Pip install jax[cuda] in a Python virtual environment in the current directory:
python -m venv .venv
source .venv/bin/activate
pip install -r requirements.txt
Note: Sourcing this virtual environment sets PJRT_PLUGIN_FULL_PATH_CONFIG, which is required by CMake to find the PJRT plugin path as installed by JAX.
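For reference, PJRT treats the copied compile option bytes as a serialized compile-options protobuf. The sketch below shows the shape of the array in pjrt/client.cpp; the single 0x00 is a placeholder, not valid options data.

```cpp
// Placeholder sketch: paste the bytes printed by get_compile_options.ipynb
// here. 0x00 is not a valid serialized compile-options message.
static const char compileOptionsData[] = {
    0x00,
};
```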
To build:
cmake . -B build
cmake --build build
To run the examples in this repository, first generate the example programs in StableHLO text format; StableHLO is the input format that PJRT expects for a program you'd like to compile.
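For orientation, at the C API level the StableHLO module is passed as MLIR text tagged with the "mlir" format string. The sketch below shows roughly how that looks; readFile is a made-up helper, and the PJRT_Program field names should be checked against the pjrt_c_api.h revision you vendored.

```cpp
// Sketch: hand a StableHLO text program to the PJRT C API.
#include <fstream>
#include <sstream>
#include <string>
#include "pjrt_c_api.h"

// Hypothetical helper: read a .stablehlo file generated by the example scripts.
std::string readFile(const char* path) {
  std::ifstream in(path);
  std::ostringstream out;
  out << in.rdbuf();
  return out.str();
}

int main(int argc, char** argv) {
  if (argc < 2) return 1;
  std::string stablehlo = readFile(argv[1]);

  PJRT_Program program{};
  program.struct_size = PJRT_Program_STRUCT_SIZE;
  program.code = stablehlo.data();   // MLIR text of the StableHLO module
  program.code_size = stablehlo.size();
  program.format = "mlir";           // "mlir" selects MLIR/StableHLO input,
  program.format_size = 4;           // as opposed to "hlo" (HloModuleProto)

  // A wrapper would now pass `program`, together with the serialized compile
  // options, to the plugin's PJRT_Client_Compile function.
  (void)program;
  return 0;
}
```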
There are a few examples:
- scalar_add_1: A simple example of adding two scalars.
- vector_add_1: A simple example of adding two vectors.
- mnist: A more complex example of training a Convolutional Neural Network (CNN) on the MNIST dataset.
For the simple examples, you can generate the StableHLO like this:
python examples/scalar_add_1/myJax.py > scalar_add_1.stablehlo
Now, run the example with the generated StableHLO program:
./build/examples/scalar_add_1/scalar_add_1 scalar_add_1.stablehlo
For the MNIST example, please see the README in the mnist example directory.
Dive into the pjrt directory to see the design principles.
I plan to tackle items in the TODO list.