Important
TPU Raiden is currently under active development and is not yet recommended for general use. If you are interested in adopting this library, please reach out to the owners first to discuss compatibility, or proceed at your own risk.
You will need a Python environment to run the JAX or PyTorch code. The code has been verified with Python 3.12, so the following should set you up:
cd
python3.12 -m venv .venv312
source .venv312/bin/activate
To compile the tpu_raiden C++ extension binaries, you will need Bazel 7.7.0.
Option 1: Install Bazel 7.7.0 directly (Linux amd64)
sudo wget -O /usr/local/bin/bazel https://github.com/bazelbuild/bazel/releases/download/7.7.0/bazel-7.7.0-linux-x86_64
sudo chmod +x /usr/local/bin/bazel
Option 2: Install via Bazelisk (npm)
Bazelisk is a wrapper that automatically reads the .bazelversion file in the project and downloads the correct Bazel version (7.7.0).
npm install -g @bazel/bazelisk
Verify the installation:
bazel --version
To compile and link the PyTorch C++ extension (_tpu_raiden_torch.so), you MUST install patchelf:
sudo apt-get install -y patchelf
Why this is necessary: the build uses patchelf to add a NEEDED entry for libpywrap_torch_tpu_common.so to the compiled PyTorch extension. This ensures TPU backend symbols resolve locally at import time, without triggering fatal crashes from duplicate XLA allocator registration.
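Once Bazel and patchelf are installed, a small preflight check can catch a wrong Bazel version before a long build starts. The script below is a minimal sketch, not something shipped with the repository: it assumes `bazel --version` prints `bazel <version>` (as release builds do), takes the expected version from `.bazelversion` when that file is in the current directory, and otherwise falls back to 7.7.0. The function names are hypothetical.

```bash
#!/usr/bin/env bash
# Preflight sketch: confirm the Bazel version and that patchelf is installed.
# Function names are hypothetical; this is not a script shipped with the repo.
set -u

# Prefer the project's pinned version from .bazelversion; fall back to 7.7.0.
if [ -f .bazelversion ]; then
  EXPECTED_BAZEL="$(tr -d '[:space:]' < .bazelversion)"
else
  EXPECTED_BAZEL="7.7.0"
fi

# bazel_version_ok OUTPUT EXPECTED: succeed if OUTPUT is exactly "bazel EXPECTED".
bazel_version_ok() {
  [ "$1" = "bazel $2" ]
}

# preflight: run the checks and report the first problem found.
preflight() {
  local out
  if ! out="$(bazel --version 2>/dev/null)"; then
    echo "could not run 'bazel --version'" >&2
    return 1
  fi
  if ! bazel_version_ok "$out" "$EXPECTED_BAZEL"; then
    echo "expected bazel $EXPECTED_BAZEL, got: $out" >&2
    return 1
  fi
  if ! command -v patchelf >/dev/null 2>&1; then
    echo "patchelf not found on PATH (required for the torch build)" >&2
    return 1
  fi
  echo "preflight OK: bazel $EXPECTED_BAZEL, patchelf present"
}

# Usage, from the repository root after sourcing this file:
#   preflight
```

Comparing against `.bazelversion` rather than a hard-coded string keeps the check in step with whatever version the project pins, which is the same file Bazelisk reads.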
- Disk Space: Remote Bazel builds on standard TPU VMs can exhaust the disk space in /tmp. Always point the Bazel output to a directory with enough free space:
  export BAZEL_OUTPUT_BASE=$YOUR_TMP_DIR_WITH_ENOUGH_SPACE
- PyTorch Wheel Compatibility: Ensure your environment aligns with torch_tpu's pinned C++ ABI expectations (e.g., torch==2.11.0+cpu).
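The disk-space advice above can be checked before exporting the variable. The helper below is a minimal sketch, assuming a POSIX-compatible `df`; the 100 GiB default for `MIN_FREE_GB` is an illustrative threshold, not a requirement stated by this project. If you invoke Bazel directly rather than through build.sh, the corresponding Bazel startup option is `--output_base` (for example `bazel --output_base="$BAZEL_OUTPUT_BASE" build ...`).

```bash
#!/usr/bin/env bash
# Sketch: verify a directory has enough free space before using it as the
# Bazel output base. MIN_FREE_GB is an illustrative threshold (assumption).
set -u

MIN_FREE_GB="${MIN_FREE_GB:-100}"

# free_kb DIR: print the space available on DIR's filesystem, in KiB.
# `df -P` forces one line per filesystem, so field 4 of line 2 is "Available".
free_kb() {
  df -Pk "$1" | awk 'NR == 2 { print $4 }'
}

# has_free_space DIR MIN_GB: succeed if DIR has at least MIN_GB GiB available.
has_free_space() {
  local avail
  avail="$(free_kb "$1")"
  [ -n "$avail" ] && [ "$avail" -ge $(( $2 * 1024 * 1024 )) ]
}

# Example (the directory is a placeholder):
#   dir=/path/to/large/disk/bazel-out
#   mkdir -p "$dir"
#   has_free_space "$dir" "$MIN_FREE_GB" && export BAZEL_OUTPUT_BASE="$dir"
```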
Note
The pre-built tpu_raiden wheel will be publicly available on PyPI shortly.
If you are a Googler, you can install the pre-built tpu_raiden wheel directly from our Google Artifact Registry.
- Install the Artifact Registry keyring helper to enable authenticated pip downloads:
pip install keyrings.google-artifactregistry-auth
- Install the framework-specific wheel:
- For JAX version:
pip install tpu-raiden-jax --extra-index-url https://us-python.pkg.dev/cloud-tpu-inference-test/tpu-raiden/simple/
  - For PyTorch version: a Torch-specific wheel will be published soon.
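If you install from the registry repeatedly, you can avoid retyping the flag: pip reads its long command-line options from `PIP_<OPTION>` environment variables, so `--extra-index-url` maps to `PIP_EXTRA_INDEX_URL`. A minimal sketch for the current shell session; the keyring helper above is still what provides authentication.

```bash
# Make pip consult the tpu-raiden Artifact Registry index in addition to PyPI
# for the rest of this shell session (equivalent to passing --extra-index-url).
export PIP_EXTRA_INDEX_URL="https://us-python.pkg.dev/cloud-tpu-inference-test/tpu-raiden/simple/"

# Later installs no longer need the flag, e.g.:
#   pip install tpu-raiden-jax
```

To persist the setting across sessions, `pip config set global.extra-index-url <URL>` writes the same option to your pip configuration file.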
We provide a script that handles the build process and compiles the extension binaries locally. You can scope compilation to specific frameworks:
./build.sh [jax|torch|both]
What this script does:
- Navigates to the workspace directory.
- Compiles the selected extension modules (_tpu_raiden_jax.so and/or _tpu_raiden_torch.so) using Bazel.
- For PyTorch builds, runs patchelf --add-needed on the generated shared library.
- Installs the Python dependencies listed in requirements.txt.
- Copies the compiled .so extension binaries into their respective framework source packages.
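After a torch build, you can confirm that the patchelf step took effect by listing the library's NEEDED entries with `patchelf --print-needed`. This is a minimal sketch: the function names are hypothetical, and the path to `_tpu_raiden_torch.so` depends on where build.sh copied it in your checkout.

```bash
#!/usr/bin/env bash
# Sketch: check that the torch extension has the DT_NEEDED entry that build.sh
# adds with `patchelf --add-needed`. Function names are hypothetical.
set -u

REQUIRED_LIB="libpywrap_torch_tpu_common.so"

# has_needed LIB: read one NEEDED entry per line on stdin; succeed if LIB is
# listed exactly (fixed-string, whole-line match).
has_needed() {
  grep -qxF "$1"
}

# check_torch_extension SO_PATH: list SO_PATH's NEEDED entries and look for
# REQUIRED_LIB.
check_torch_extension() {
  if patchelf --print-needed "$1" | has_needed "$REQUIRED_LIB"; then
    echo "OK: $1 has NEEDED $REQUIRED_LIB"
  else
    echo "MISSING: $1 lacks NEEDED $REQUIRED_LIB" >&2
    return 1
  fi
}

# Example (placeholder path; use wherever build.sh copied the library):
#   check_torch_extension path/to/_tpu_raiden_torch.so
```

A whole-line fixed-string match avoids false positives from similarly named libraries that merely contain the required name as a substring.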
Once the build is complete, you can run the core functional unit tests, which verify the correctness of the foundational components and APIs, for JAX and/or PyTorch:
./run_tests.sh [jax|torch|both]
What this script does:
- Sets up PYTHONPATH so Python can locate the compiled modules in bazel-bin and the framework wrapper modules.
- Runs the selected unit test suites for JAX and/or PyTorch directly via python.
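If you want to run a single test file by hand instead of through run_tests.sh, you need an equivalent PYTHONPATH. The helper below is a sketch of that step, not the script's actual code: it prepends directories in the order given and avoids leaving a stray empty entry when PYTHONPATH was previously unset. The directory names in the usage comment are placeholders.

```bash
#!/usr/bin/env bash
# Sketch of the PYTHONPATH step; function name and paths are hypothetical.
set -u

# prepend_pythonpath DIR...: prepend the DIRs (keeping their order) to
# PYTHONPATH, without a trailing ":" when PYTHONPATH was unset or empty.
prepend_pythonpath() {
  local new="" d
  for d in "$@"; do
    new="${new:+$new:}$d"
  done
  export PYTHONPATH="$new${PYTHONPATH:+:$PYTHONPATH}"
}

# Example (placeholder paths; check run_tests.sh for the real ones):
#   prepend_pythonpath "$PWD/bazel-bin" "$PWD/<framework wrapper dir>"
#   python path/to/some_test.py
```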
If you'd like to try out Raiden and see it in action, please refer to the examples/ directory. It contains hands-on scripts for interacting with the library, including testing scripts and performance microbenchmarks that demonstrate Raiden's capabilities.
For detailed instructions on how to run these examples and interpret their outputs, please check out the Examples README.